# Owner(s): ["module: nn"]
import itertools
import math
import operator
import os
import random
import subprocess
import sys
import unittest
from functools import partial, reduce
from itertools import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import inf, nan
from torch.autograd import gradcheck, gradgradcheck
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    expectedFailureMeta,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipCUDAIfRocm,
    TEST_WITH_ROCM,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_nn import (
    _test_bfloat16_ops,
    _test_module_empty_input,
    NNTestCase,
)
from torch.testing._internal.common_utils import (
    gcIfJetson,
    instantiate_parametrized_tests,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    skipIfMps,
    skipIfTorchDynamo,
    slowTest,
    subtest,
    TEST_WITH_UBSAN,
    TestCase,
)


class TestAvgPool(TestCase):
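    # Reference helper: non-overlapping sum pooling via unfold, which flattens
    # each kernel window into a column; summing over dim=1 pools each window.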
    def _sum_pool2d(self, x, kernel_size):
        windows = torch.nn.functional.unfold(
            x, kernel_size=kernel_size, stride=kernel_size
        )
        return torch.sum(windows, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        # unfold does not support 3D sliding windows, so split the tensor along
        # the depth dimension and sum each complete group of slices
        h = kernel_size[0]
        split_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # _sum_pool2d expects a (1, 1, n, m) view, so unsqueeze twice
        split_x = [
            self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:])
            for t in split_x
        ]
        joined_x = torch.cat(split_x)
        return joined_x.view(1, joined_x.numel())

    def _avg_pool2d(self, x, kernel_size):
        size = reduce(operator.mul, kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        size = reduce(operator.mul, kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

    def test_doubletensor_avg_pool2d(self):
        n, m = 5, 8
        input = torch.rand(1, 1, n, m, dtype=torch.double)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                actual = torch.nn.functional.avg_pool2d(input[0], (i, j))
                actual = actual.view(1, actual.numel())
                expected = self._avg_pool2d(input, (i, j))
                self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_doubletensor_avg_pool2d_with_divisor(self):
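        # divisor_override replaces the kernel area as the averaging denominator.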
        n, m = 3, 3
        input = torch.rand(1, 1, n, m, dtype=torch.double)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_doubletensor_avg_pool3d(self):
        h, w, d = 5, 6, 7
        input = torch.rand(h, w, d, dtype=torch.double)
        for i in range(1, h + 1):
            for j in range(1, w + 1):
                for k in range(1, d + 1):
                    actual = torch.nn.functional.avg_pool3d(
                        input.unsqueeze(0), (i, j, k)
                    )
                    actual = actual.view(1, actual.numel())
                    expected = self._avg_pool3d(input, (i, j, k))
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_doubletensor_avg_pool3d_with_divisor(self):
        h, w, d = 6, 5, 7
        input = torch.rand(h, w, d, dtype=torch.double)
        for i in range(1, h + 1):
            for j in range(1, w + 1):
                for k in range(1, d + 1):
                    for divisor in [1, 7, i * j]:
                        actual = torch.nn.functional.avg_pool3d(
                            input.unsqueeze(0), (i, j, k), divisor_override=divisor
                        )
                        actual = actual.view(1, actual.numel())
                        expected = self._sum_pool3d(input, (i, j, k)) / divisor
                        self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool1d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4))
        y = torch.nn.functional.avg_pool1d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=1, stride=2
        )
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool1d(
                x.to("cuda"),
                ceil_mode=True,
                count_include_pad=True,
                kernel_size=1,
                stride=2,
            )
            self.assertTrue(not torch.isnan(y).any())

    def test_avg_pool2d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x,
            ceil_mode=True,
            count_include_pad=True,
            kernel_size=(1, 2),
            padding=(0, 1),
            stride=2,
        )
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool2d(
                x.to("cuda"),
                ceil_mode=True,
                count_include_pad=True,
                kernel_size=(1, 2),
                padding=(0, 1),
                stride=2,
            )
            self.assertTrue(not torch.isnan(y).any())

    def test_avg_pool3d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4, 4))
        y = torch.nn.functional.avg_pool3d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2, 3), stride=2
        )
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool3d(
                x.to("cuda"),
                ceil_mode=True,
                count_include_pad=True,
                kernel_size=(1, 2, 3),
                stride=2,
            )
            self.assertTrue(not torch.isnan(y).any())


class TestPoolingNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def test_adaptive_pooling_size_none(self):
        for numel in (2, 3):
            for pool_type in ("Max", "Avg"):
                cls_name = f"Adaptive{pool_type}Pool{numel}d"
                module_cls = getattr(nn, cls_name)
                output_size = (2,) * (numel - 1) + (None,)
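                # A None entry keeps the input size along that dimension.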
                module = module_cls(output_size)

                input = torch.randn((4,) * (numel + 1))
                output = module(input)
                self.assertEqual(output.size(), (4,) + (2,) * (numel - 1) + (4,))

    @unittest.skipIf(TEST_WITH_UBSAN, "signed integer overflow error with UBSAN")
    def test_adaptive_pooling_size_overflow(self):
        # 0x3fffffffffffffff * 2 * 2 = 0xfffffffffffffffc = -4 as int64_t
        # Tensor::numel() returns int64_t, so the following checks that negative
        # allocations are handled correctly
        self.assertRaises(
            RuntimeError,
            lambda: torch.nn.AdaptiveMaxPool1d(0x3FFFFFFFFFFFFFFF)(
                torch.empty([2, 2, 2])
            ),
        )

    def test_adaptive_pooling_avg_nhwc(self):
        device_list = ["cpu"]
        if TEST_CUDA:
            device_list.append("cuda")

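        # Run the pool on a channels_last input and compare both the forward
        # output and the input gradient against a contiguous reference copy.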
        for device in device_list:
            input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device)
            input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
            grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device)
            pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

    def test_adaptive_pooling_avg_nhwc_non_contiguous(self):
        device_list = ["cpu"]
        if TEST_CUDA:
            device_list.append("cuda")

        for device in device_list:
            input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device)
            input = input.contiguous(memory_format=torch.channels_last)
            input = input[:, ::2, :, :].requires_grad_()
            grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device)
            grad = grad[:, ::2, :, :]
            pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

    def test_adaptive_pooling_lower_precision(self):
        def _test_adaptive_pooling_lower_precision(
            self, device, dtype, mod, memory_format
        ):
            input = torch.randint(1, 10, (3, 19, 8, 8), dtype=torch.float32)
            input = input.to(device).to(memory_format=memory_format).requires_grad_()
            pool = mod((7, 7)).to(device)

            input2 = input.detach().clone().to(dtype=dtype).requires_grad_(True)

            out = pool(input)
            out.sum().backward()
            out2 = pool(input2)
            out2.sum().backward()

            self.assertTrue(out2.is_contiguous(memory_format=memory_format))
            self.assertEqual(out2.dtype, dtype)
            self.assertEqual(input2.grad.dtype, dtype)
            self.assertEqual(out, out2.float(), atol=0.1, rtol=0)
            self.assertEqual(input.grad, input2.grad.float(), atol=0.1, rtol=0)

        device_list = ["cpu"]
        for device in device_list:
            for dtype in [torch.bfloat16, torch.float16]:
                _test_adaptive_pooling_lower_precision(
                    self,
                    device,
                    dtype,
                    torch.nn.AdaptiveAvgPool2d,
                    torch.contiguous_format,
                )
                _test_adaptive_pooling_lower_precision(
                    self, device, dtype, torch.nn.AdaptiveAvgPool2d, torch.channels_last
                )
                _test_adaptive_pooling_lower_precision(
                    self,
                    device,
                    dtype,
                    torch.nn.AdaptiveMaxPool2d,
                    torch.contiguous_format,
                )
                _test_adaptive_pooling_lower_precision(
                    self, device, dtype, torch.nn.AdaptiveMaxPool2d, torch.channels_last
                )

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @largeTensorTest("12GB", device="cuda")
    def test_adaptive_pooling_avg_nhwc_launch_config_backward(self):
        input = torch.randint(
            1, 10, (1, 32, 2**17 + 1, 32), dtype=torch.float32, device="cuda"
        )
        input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
        grad = torch.randint(1, 10, (1, 32, 10, 32), dtype=torch.float32, device="cuda")

        pool = torch.nn.AdaptiveAvgPool2d((10, 32)).cuda()

        ref_input = input.detach().clone().contiguous().requires_grad_(True)
        ref_grad = grad.detach().clone().contiguous()
        ref_pool = torch.nn.AdaptiveAvgPool2d((10, 32)).cuda()

        out = pool(input)
        out.backward(grad)
        ref_out = ref_pool(ref_input)
        ref_out.backward(ref_grad)

        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(ref_out.is_contiguous())
        self.assertEqual(out, ref_out)
        self.assertEqual(input.grad, ref_input.grad)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @largeTensorTest("12GB", device="cuda")
    def test_adaptive_pooling_avg_nhwc_launch_config_forward(self):
        input = torch.randint(
            1, 10, (1, 32, 16, 16), dtype=torch.float32, device="cuda"
        )
        input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
        pool = torch.nn.AdaptiveAvgPool2d((2**17 + 1, 32)).cuda()

        ref_input = input.detach().clone().contiguous().requires_grad_(True)
        ref_pool = torch.nn.AdaptiveAvgPool2d((2**17 + 1, 32)).cuda()

        out = pool(input)
        ref_out = ref_pool(ref_input)

        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(ref_out.is_contiguous())
        self.assertEqual(out, ref_out)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_adaptive_avg_pooling_overflow(self):
        input = torch.randint(
            -256, 256, (20, 32, 256, 256), dtype=torch.half, device="cuda"
        )
        avg_pool = torch.nn.AdaptiveAvgPool2d((2, 2))
        out = avg_pool(input)
        self.assertFalse(torch.isinf(out).any())
        self.assertFalse(torch.isnan(out).any())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_adaptive_avg_pooling_nhwc_overflow(self):
        input = torch.randint(
            -256, 256, (20, 32, 256, 256), dtype=torch.half, device="cuda"
        )
        input = input.contiguous(memory_format=torch.channels_last)
        avg_pool = torch.nn.AdaptiveAvgPool2d((2, 2))
        out = avg_pool(input)
        self.assertFalse(torch.isinf(out).any())
        self.assertFalse(torch.isnan(out).any())

    def test_MaxUnpool2d_output_size(self):
        m = nn.MaxPool2d(3, stride=2, return_indices=True)
        mu = nn.MaxUnpool2d(3, stride=2)
        big_t = torch.rand(1, 1, 6, 6)
        big_t[0][0][4][4] = 100
        output_big, indices_big = m(big_t)
        self.assertRaises(RuntimeError, lambda: mu(output_big, indices_big))

        small_t = torch.rand(1, 1, 5, 5)
        for i in range(0, 4, 2):
            for j in range(0, 4, 2):
                small_t[:, :, i, j] = 100
        output_small, indices_small = m(small_t)
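        # The pooled 2x2 map unpools to a default size of 5 per dim (kernel 3,
        # stride 2); requested sizes strictly within one stride of the default
        # (4..6) are accepted.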
        for h in range(3, 10):
            for w in range(3, 10):
                if 4 <= h <= 6 and 4 <= w <= 6:
                    size = (h, w)
                    if h == 6:
                        size = (1, 1) + size

                    mu(output_small, indices_small, output_size=size)
                else:
                    self.assertRaises(
                        ValueError, lambda: mu(output_small, indices_small, (h, w))
                    )

    def test_max_unpool2d_nhwc_cpu(self):
        input = torch.randn(2, 10, 9, 9).float().cpu()
        input = input.contiguous(memory_format=torch.channels_last)
        ref_input = input.clone().contiguous()

        pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()
        ref_pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()

        out, ind = pool(input)
        ref_out, ref_ind = ref_pool(ref_input)
        out.requires_grad_()
        ref_out.requires_grad_()

        unpool = nn.MaxUnpool2d(3, stride=2).cpu()
        ref_unpool = nn.MaxUnpool2d(3, stride=2).cpu()

        upout = unpool(out, ind)
        ref_upout = ref_unpool(ref_out, ref_ind)

        grad = torch.randn(upout.size()).float().cpu()
        grad = grad.contiguous(memory_format=torch.channels_last)
        ref_grad = grad.clone().contiguous()

        upout.backward(grad)
        ref_upout.backward(ref_grad)

        self.assertTrue(upout.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(ref_upout.is_contiguous())
        self.assertTrue(torch.allclose(upout, ref_upout))
        self.assertTrue(torch.allclose(out.grad, ref_out.grad))

    def test_max_unpool(self):
        with set_default_dtype(torch.double):
            # Test 1D
            output, indices = F.max_pool1d(
                torch.randn([1, 1, 4]), 2, stride=2, return_indices=True
            )
            self.assertEqual(
                F.max_unpool1d(output, indices, 2),
                F.max_unpool1d(output, indices, 2, stride=2),
            )

            # Test list / tuple passed as argument to max_unpool1d
            input = torch.randn([1, 1, 5], requires_grad=True)
            output, indices = F.max_pool1d(input, 2, stride=2, return_indices=True)
            self.assertEqual(
                F.max_unpool1d(output, indices, 2, stride=2, output_size=input.shape),
                F.max_unpool1d(output, indices, 2, stride=2, output_size=input.size()),
            )
            gradcheck(F.max_unpool1d, (output, indices, 2), check_forward_ad=True)

            # Test 2D
            output, indices = F.max_pool2d(
                torch.randn([1, 1, 4, 4], requires_grad=True),
                2,
                stride=2,
                return_indices=True,
            )
            self.assertEqual(
                F.max_unpool2d(output, indices, 2),
                F.max_unpool2d(output, indices, 2, stride=2),
            )
            gradcheck(F.max_unpool2d, (output, indices, 2), check_forward_ad=True)

            # Test 3D
            output, indices = F.max_pool3d(
                torch.randn([4, 4, 4, 4, 4], requires_grad=True),
                2,
                stride=2,
                return_indices=True,
            )
            self.assertEqual(
                F.max_unpool3d(output, indices, 2),
                F.max_unpool3d(output, indices, 2, stride=2),
            )
            gradcheck(F.max_unpool3d, (output, indices, 2), check_forward_ad=True)

    def test_max_unpool3d_input_check(self):
        x = torch.ones(1, 3, 1, 1, 1)
        with self.assertRaises(RuntimeError):
            F.max_unpool3d(x, torch.zeros(x.shape, dtype=int), [1, 1])

    def test_quantized_max_pool1d_empty_kernel(self):
        # This used to segfault when called with an empty kernel
        # see https://github.com/pytorch/pytorch/issues/116323
        base = torch.randn(1)
        temp_tensor = torch.quantize_per_tensor(base, 0.1, 10, torch.quint2x4)
        with self.assertRaises(RuntimeError):
            torch.quantized_max_pool1d(temp_tensor, [])


class TestPoolingNNDeviceType(NNTestCase):
    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    def test_adaptive_pooling_zero_batch(self, dtype, device):
        inp = torch.ones(0, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool1d(5).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        inp = torch.ones(0, 10, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool2d((5, 5)).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        inp = torch.ones(0, 10, 10, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool3d((5, 5, 5)).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    # These tests verify that adaptive_{avg, max}_pool and their variants raise
    # an error during backward propagation when output_size = 0.
    # They are written explicitly because ErrorInputs does not support backward calls.
    # Issue: https://github.com/pytorch/pytorch/issues/78868
    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16, torch.float16)
    def test_adaptive_pooling_empty_output_size(self, dtype, device):
        error_msg = (
            "Expected grad_output to have non-zero size for non-batch dimensions"
        )

        make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
        input = make_arg((1, 64, 10, 9))
        output_size = 0

        fns = (
            nn.functional.adaptive_avg_pool2d,
            nn.functional.adaptive_avg_pool3d,
            nn.functional.adaptive_max_pool2d,
            nn.functional.adaptive_max_pool3d,
        )

        for fn in fns:
            with self.assertRaisesRegex(RuntimeError, error_msg):
                fn(input, output_size).sum().backward()

        fns2 = (
            nn.functional.adaptive_avg_pool1d,
            nn.functional.adaptive_max_pool1d,
        )
        input2 = make_arg((1, 64))

        for fn in fns2:
            with self.assertRaisesRegex(RuntimeError, error_msg):
                fn(input2, output_size).sum().backward()

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool2d_zero_batch(self, device):
        mod = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
        inp = torch.ones(0, 16, 50, 32, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected input"):
            inp = torch.randn(1, 0, 50, 32, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool3d_zero_batch(self, device):
        mod = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)).to(device)
        inp = torch.ones(0, 16, 50, 32, 32, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected input"):
            inp = torch.randn(1, 0, 50, 32, 32, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool2d_zero_out_size(self, device):
        mod = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1])
        inp = torch.rand([16, 50, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((16, 50, 0, 1), device=device))

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool3d_zero_out_size(self, device):
        mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[0, 1, 1])
        inp = torch.rand([16, 50, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device))

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool2d_zero_samples(self, device):
        samples = torch.rand([0, 16, 2], device=device)
        mod = nn.FractionalMaxPool2d(
            [2, 2], output_size=[1, 1], _random_samples=samples
        )
        inp = torch.randn([0, 16, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((0, 16, 1, 1), device=device))

        inp1 = torch.randn([1, 16, 32, 32], device=device)
        with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"):
            out1 = mod(inp1)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool3d_zero_samples(self, device):
        samples = torch.rand([0, 16, 3], device=device)
        mod = nn.FractionalMaxPool3d(
            [3, 2, 2], output_size=[1, 1, 1], _random_samples=samples
        )
        inp = torch.randn([0, 16, 50, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((0, 16, 1, 1, 1), device=device))

        inp1 = torch.randn([1, 16, 50, 32, 32], device=device)
        with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"):
            out1 = mod(inp1)

    @onlyNativeDeviceTypes
    def test_MaxPool_zero_batch_dim(self, device):
        inp = torch.randn(0, 16, 50, device=device)
        mod = torch.nn.MaxPool1d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        # 1D is supposed to be okay with 0-numel() inputs, so don't test
        # error raising for that case.

        inp = torch.randn(0, 16, 50, 32, device=device)
        mod = torch.nn.MaxPool2d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.randn(1, 0, 50, 32, device=device)
            mod(inp)

        inp = torch.ones(0, 16, 50, 44, 31, device=device)
        mod = torch.nn.MaxPool3d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.ones(1, 0, 50, 44, 31, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_MaxUnpool_zero_batch_dim(self, device):
        pool = torch.nn.MaxPool1d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool1d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        output.requires_grad_(True)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

        pool = torch.nn.MaxPool2d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool2d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

        pool = torch.nn.MaxPool3d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool3d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        output.requires_grad_(True)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

    @slowTest
    @onlyNativeDeviceTypes
    @skipCUDAIfRocm
    @parametrize_test(
        "module_name,module_size,output_size,test_index,should_error",
        [
            # Some tests are failing in trunk https://github.com/pytorch/pytorch/issues/103854
            subtest(
                ("MaxUnpool2d", (2, 2), (1, 3, 4, 5), -1, True),
                name="case1",
            ),
            subtest(
                ("MaxUnpool2d", (2, 2), (1, 3, 4, 5), 2 * 2 * 4 * 5, True),
                name="case2",
            ),
            subtest(
                ("MaxUnpool2d", (2, 2), (1, 3, 4, 5), (2 * 2 * 4 * 5) - 1, False),
                name="case3",
            ),
            subtest(
                ("MaxUnpool2d", (2, 3), (2, 1, 4, 2), 2 * 3 * 4 * 2, True),
                name="case4",
            ),
            subtest(
                ("MaxUnpool2d", (2, 3), (2, 1, 4, 2), (2 * 3 * 4 * 2) - 1, False),
                name="case5",
            ),
            subtest(
                ("MaxUnpool3d", (2, 2, 2), (1, 3, 4, 5), -1, True),
                name="case6",
            ),
            subtest(
                ("MaxUnpool3d", (2, 2, 2), (1, 3, 4, 5), 2 * 2 * 2 * 3 * 4 * 5, True),
                name="case7",
            ),
            subtest(
                (
                    "MaxUnpool3d",
                    (2, 2, 2),
                    (1, 3, 4, 5),
                    (2 * 2 * 2 * 3 * 4 * 5) - 1,
                    False,
                ),
                name="case8",
            ),
            subtest(
                ("MaxUnpool3d", (2, 2, 2), (2, 3, 4, 1), 2 * 2 * 2 * 3 * 4 * 1, True),
                name="case9",
            ),
            subtest(
                (
                    "MaxUnpool3d",
                    (2, 2, 2),
                    (2, 3, 4, 1),
                    (2 * 2 * 2 * 3 * 4 * 1) - 1,
                    False,
                ),
                name="case10",
            ),
        ],
    )
    def test_MaxUnpool_index_errors(
        self, device, module_name, module_size, output_size, test_index, should_error
    ):
        # NOTE: CUDA tests need to be run in a subprocess because they cause device asserts
        if torch.device(device).type == "cuda":
            error_msgs = {
                "MaxUnpool2d": r"Assertion `maxind >= 0 && maxind < outputImageSize` failed",
                "MaxUnpool3d": r"Assertion `index >= 0 && index < outputImageSize` failed",
            }

            script = f"""
import torch
unpool = torch.nn.{module_name}({module_size}).to('{device}')
output = torch.rand({output_size}, dtype=torch.float32, device='{device}')
indices = torch.zeros({output_size}, dtype=torch.int64, device='{device}')
indices.flatten()[0] = {test_index}
unpool(output, indices)
torch.cuda.synchronize()
"""
            p = subprocess.run(
                [sys.executable, "-c", script],
                cwd=os.path.dirname(os.path.realpath(__file__)),
                capture_output=True,
                text=True,
            )

            output = p.stdout + "\n" + p.stderr

            error_msg = error_msgs[module_name]

            if should_error:
                self.assertIn(error_msg, output, "The expected error was not found")
            else:
                self.assertNotIn("Error", output, "Should not have produced an error")
        else:
            module_class = getattr(torch.nn, module_name)
            unpool = module_class(module_size).to(device)
            output = torch.rand(output_size, dtype=torch.float32, device=device)
            indices = torch.zeros(output_size, dtype=torch.int64, device=device)
            indices.flatten()[0] = test_index

            if should_error:
                with self.assertRaisesRegex(
                    RuntimeError, r"Found an invalid max index:"
                ):
                    unpool(output, indices)
            else:
                unpool(output, indices)

    @onlyNativeDeviceTypes
    def test_AdaptiveMaxPool_zero_batch_dim(self, device):
        inp = torch.randn(0, 16, 50, device=device)
        mod = torch.nn.AdaptiveMaxPool1d(3).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.randn(1, 0, 50, device=device)
            mod(inp)

        inp = torch.randn(0, 16, 50, 32, device=device)
        mod = torch.nn.AdaptiveMaxPool2d(3).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.randn(1, 0, 50, 32, device=device)
            mod(inp)

        inp = torch.ones(0, 16, 50, 44, 31, device=device)
        mod = torch.nn.AdaptiveMaxPool3d(3).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.ones(1, 0, 50, 44, 31, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_AvgPool2d_empty(self, device):
        avgpool = torch.nn.AvgPool2d(3, stride=2).to(device)
        inp = torch.randn(0, 16, 20, 32, device=device)
        _test_module_empty_input(self, avgpool, inp, check_size=False)

        clast_inp = torch.randn(0, 16, 20, 32, device=device).contiguous(
            memory_format=torch.channels_last
        )
        _test_module_empty_input(self, avgpool, clast_inp, check_size=False)

        # test with empty non-batch input
        with self.assertRaisesRegex(RuntimeError, "3D or 4D"):
            inp = torch.randn(16, 0, 20, 32, device=device)
            avgpool(inp)

    def test_pooling_shape(self, device):
        """Test the output shape calculation for pooling functions"""

        # Checks output shape against expected for 1D, 2D and 3D
        def check(expected_out_shape, sizes, *args, **kwargs):
            for kernel in ["max", "avg"]:
                for i in [1, 2, 3]:
                    if hasattr(torch.nn.functional, f"{kernel}_pool{i}d"):
                        op = getattr(torch.nn.functional, f"{kernel}_pool{i}d")
                        t = torch.randn(sizes[: i + 2], device=device)
                        self.assertEqual(
                            op(t, *args, **kwargs).shape, expected_out_shape[: i + 2]
                        )

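        # Expected shapes follow the pooling size formula (dilation = 1):
        #   out = floor_or_ceil((size + 2 * padding - kernel_size) / stride) + 1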
        check(
            (1, 1, 3, 3, 4),
            (1, 1, 5, 6, 7),
            kernel_size=1,
            stride=2,
            padding=0,
            ceil_mode=True,
        )
        check(
            (1, 1, 2, 3, 3),
            (1, 1, 3, 4, 5),
            kernel_size=2,
            stride=2,
            padding=1,
            ceil_mode=False,
        )
        check(
            (1, 1, 2, 3, 3),
            (1, 1, 3, 4, 5),
            kernel_size=2,
            stride=2,
            padding=1,
            ceil_mode=True,
        )

        # Test case from issue https://github.com/pytorch/pytorch/issues/45357
        x = torch.randn(1, 1, 6, 7, device=device)
        y = torch.nn.functional.max_pool2d(
            x, 1, stride=(2, 2), padding=0, ceil_mode=True
        )
        self.assertEqual(y.size(), (1, 1, 3, 4))

    @onlyNativeDeviceTypes  # TODO: fix on XLA
    def test_adaptive_avg_pool2d_output_size_one(self, device):
        def helper(size, memory_format):
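            # Adaptive average pooling to (1, 1) is equivalent to a mean over
            # the spatial dimensions, which serves as the reference here.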
            x = torch.randint(
                1, 10, size, dtype=torch.float, device=device, requires_grad=True
            )
            if memory_format == "non_contiguous":
                x = x[::2, ::2, ::2, ::2]
            else:
                x = x.to(memory_format=memory_format)

            net = torch.nn.AdaptiveAvgPool2d((1, 1))
            out = net(x)
            ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))

            out.sum().backward()  # make sure it doesn't crash

            self.assertEqual(out, ref_out)
            if memory_format == torch.channels_last:
                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, c, c])
            else:
                self.assertTrue(out.is_contiguous())
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, 1, 1])

        for mf in (torch.contiguous_format, torch.channels_last, "non_contiguous"):
            helper((2, 3, 6, 6), mf)

    @onlyNativeDeviceTypes
    def test_adaptive_avg_pool3d_output_size_one(self, device):
        x = torch.randn(
            (2, 3, 6, 6, 6), dtype=torch.float, device=device, requires_grad=True
        )

        net = torch.nn.AdaptiveAvgPool3d(1)
        out = net(x)
        ref_out = x.contiguous().mean((-1, -2, -3)).view(out.shape)

        out.sum().backward()  # make sure it doesn't crash

        self.assertEqual(out, ref_out)
        self.assertTrue(out.is_contiguous())
        c = out.size(1)
        self.assertEqual(out.stride(), [c, 1, 1, 1, 1])

    @expectedFailureMeta  # RuntimeError not raised for meta
    @onlyNativeDeviceTypes
    @dtypes(torch.uint8, torch.int8, torch.short, torch.int, torch.long)
    def test_adaptive_pooling_no_support_input(self, device, dtype):
        for numel in (2, 3):
            for pool_type in ("Max", "Avg"):
                cls_name = f"Adaptive{pool_type}Pool{numel}d"
                module_cls = getattr(nn, cls_name)
                output_size = (2,) * numel
                module = module_cls(output_size)
                input = torch.randn((4,) * (numel + 1), device=device).to(dtype)
                with self.assertRaisesRegex(RuntimeError, "not implemented"):
                    output = module(input)

    @onlyNativeDeviceTypes
    @gcIfJetson
    @dtypes(torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    def test_avg_pool2d_nhwc(self, device, dtype):
        def helper(
            n,
            c,
            h,
            w,
            kernel_size,
            stride=None,
            count_include_pad=True,
            divisor_override=None,
            padding=0,
        ):
            if stride is None:
                stride = kernel_size
            input = torch.randn(n, c, h, w, dtype=dtype, device=device)
            input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
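            # Gradient shaped to the pooled output:
            # (dim - kernel_size) // stride + 1 per spatial dimension.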
            grad = torch.randn(
                n,
                c,
                (h - kernel_size) // stride + 1,
                (w - kernel_size) // stride + 1,
                dtype=dtype,
                device=device,
            )
            pool = torch.nn.AvgPool2d(
                kernel_size,
                stride=stride,
                count_include_pad=count_include_pad,
                divisor_override=divisor_override,
            ).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AvgPool2d(
                kernel_size,
                stride=stride,
                count_include_pad=count_include_pad,
                divisor_override=divisor_override,
            ).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

        helper(4, 8, 8, 8, 3)
        helper(4, 8, 8, 8, 3, count_include_pad=False, padding=1)
        helper(4, 8, 8, 8, 3, count_include_pad=False, padding=2, stride=2)
        helper(4, 8, 8, 8, 3, divisor_override=42)
        helper(4, 8, 8, 8, 7)
        # The 16GB ROCm MI25 can hit an OOM error; clear the caching allocator
        # before running the large subtest.
        if TEST_WITH_ROCM and "cuda" in device:
            torch.cuda.empty_cache()
        helper(200, 512, 28, 28, 2)
        helper(4, 8, 7, 7, 3, stride=1)
        helper(4, 8, 7, 7, 3, padding=2, stride=1)
        helper(10, 512, 31, 31, 3, stride=2)
        helper(1, 129, 8, 8, 3, stride=2)

    @onlyCPU
    @dtypes(torch.float, torch.double)
    def test_max_pool1d_corner_cases(self, device, dtype):
        def check(x, args, expected):
            model = torch.nn.MaxPool1d(*args)
            if isinstance(x, list):
                x = torch.tensor(x, device=device, dtype=dtype)
                expected = torch.tensor(expected, device=device, dtype=dtype)
            self.assertEqual(model(x), expected)

        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        check([[1]], (1, None, 0, 1, False, False), [[1]])
        check([[1]], (2, None, 1, 2, False, False), [[float("-inf")]])
        check(
            [[1], [1]],
            (2, None, 1, 2, False, False),
            [[float("-inf")], [float("-inf")]],
        )
        check([[1, 2]], (2, 1, 1, 2, False, False), [[2, 1]])
        check([[1, 2]], (2, 2, 1, 2, False, True), [[2, 2]])

    @onlyCPU
    @dtypes(torch.float, torch.double)
    @skipIfTorchDynamo("OOMs https://github.com/pytorch/pytorch/issues/111320")
    def test_max_pool1d(self, device, dtype):
        # FIXME: for now, compare against max_pool1d with indices
        def check(x, *args, **kwargs):
            model = torch.nn.MaxPool1d(*args, **kwargs)
            ref_model = torch.nn.MaxPool1d(*args, **kwargs, return_indices=True)
            self.assertEqual(model(x), ref_model(x)[0])

        sizes = [random.sample(range(8, 128), 3) for _ in range(3)]
        kernel_sizes = random.sample(range(1, 5), 3)
        strides = random.sample(range(1, 5), 3)
        dilations = random.sample(range(1, 5), 3)
        ceil_modes = [True, False]

        for size, kernel_size, stride, dilation, ceil_mode in itertools.product(
            sizes, kernel_sizes, strides, dilations, ceil_modes
        ):
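            # max_pool1d requires padding to be at most half the kernel size.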
            padding = random.sample(range(0, math.floor(kernel_size / 2) + 1), 1)
            check(
                torch.randn(size, device=device, dtype=dtype),
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode=ceil_mode,
            )

        # Non-contiguous test
        tensor = torch.randn(5, 151, 33, device=device, dtype=dtype)[::2, ::3, ::2]
        check(tensor, 3, 2, 1, 2, ceil_mode=True)
        check(tensor.transpose(1, 2), 3, 2, 1, 2, ceil_mode=True)

    @onlyCUDA
    @gcIfJetson
    def test_max_pool2d(self, device):
        def helper(n, c, h, w, ks):
            x = torch.randn(
                n, c, h, w, device="cuda", dtype=torch.float, requires_grad=True
            )
            ref_x = x.detach().clone().cpu().requires_grad_()

            pool = torch.nn.MaxPool2d(kernel_size=ks)

            y = pool(x)
            ref_y = pool(ref_x)

            y.sum().backward()
            ref_y.sum().backward()

            self.assertEqual(y, ref_y)
            self.assertEqual(x.grad, ref_x.grad)

        helper(2, 8, 4, 4, ks=2)
        helper(1, 100000, 32, 32, ks=4)
        helper(1, 100000, 1, 4, ks=(1, 4))  # test for max_pool1d

    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @gcIfJetson
    def test_max_pool2d_nhwc(self, device, dtype):
        def helper(n, c, h, w, kernel_size, stride=None):
            if stride is None:
                stride = kernel_size
            input = torch.randn(n, c, h, w, dtype=dtype, device=device)
            input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
            grad = torch.randn(
                n,
                c,
                (h - kernel_size) // stride + 1,
                (w - kernel_size) // stride + 1,
                dtype=dtype,
                device=device,
            )
            pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(
                device
            )

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(
                device
            )

            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            self.assertEqual(input.grad, ref_input.grad)

        helper(4, 8, 8, 8, 7)
        helper(200, 512, 28, 28, 2)
        helper(4, 8, 7, 7, 3, stride=1)
        helper(10, 512, 31, 31, 3, stride=2)
        helper(1, 129, 8, 8, 3, stride=2)

    @onlyCPU
    @dtypes(torch.int32, torch.int64)
    def test_max_pool2d_corner_cases(self, device, dtype):
        def check(x, args, expected, memory_format):
            model = torch.nn.MaxPool2d(*args)
            if isinstance(x, list):
                x = torch.tensor(x, device=device, dtype=dtype).to(
                    memory_format=memory_format
                )
                expected = torch.tensor(expected, device=device, dtype=dtype).to(
                    memory_format=memory_format
                )
            self.assertEqual(model(x), expected)

        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        check(
            [[[[-1, -2], [-3, -4]]]],
            (2, 2, 1, 2, False, True),
            [[[[-4, -4], [-4, -4]]]],
            torch.contiguous_format,
        )
        check(
            [[[[-1, -2], [-3, -4]]]],
            (2, 2, 1, 2, False, True),
            [[[[-4, -4], [-4, -4]]]],
            torch.channels_last,
        )

    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @gcIfJetson
    def test_max_pool3d_ndhwc(self, device, dtype):
        def helper(n, c, h, w, d, kernel_size, stride=None):
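            # n == 0 exercises the unbatched (C, D, H, W) path: build a batched
            # tensor first, then squeeze the batch dimension away.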
            batch = n
            if not batch:
                batch = 1
            input = torch.randn(batch, c, d, h, w, dtype=dtype, device=device)
            input = input.contiguous(
                memory_format=torch.channels_last_3d
            ).requires_grad_()
            if not n:
                input = input.squeeze(0).detach().clone().requires_grad_()
            if isinstance(kernel_size, int):
                kernel_size = [kernel_size] * 3
            if stride is None:
                stride = kernel_size
            elif isinstance(stride, int):
                stride = [stride] * 3
            grad = torch.randn(
                batch,
                c,
                (d - kernel_size[0]) // stride[0] + 1,
                (h - kernel_size[1]) // stride[1] + 1,
                (w - kernel_size[2]) // stride[2] + 1,
                dtype=dtype,
                device=device,
            )
            grad = grad.contiguous(memory_format=torch.channels_last_3d)
            if not n:
                grad = grad.squeeze(0)
            pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(
                device
            )

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(
                device
            )
            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            if len(out.shape) == 4:
                self.assertTrue(
                    out.unsqueeze(0).is_contiguous(memory_format=torch.channels_last_3d)
                )
            else:
                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(ref_out.is_contiguous())
            if len(ind.shape) == 4:
                self.assertTrue(
                    ind.unsqueeze(0).is_contiguous(memory_format=torch.channels_last_3d)
                )
            else:
                self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            if dtype == torch.half:
                self.assertEqual(input.grad, ref_input.grad, atol=0.05, rtol=0.01)
            else:
                self.assertEqual(input.grad, ref_input.grad)

        helper(4, 8, 8, 8, 8, 7)
        helper(4, 8, 8, 8, 8, (5, 6, 7))
        helper(1, 8, 8, 8, 8, (5, 6, 7))
        helper(0, 6, 12, 13, 14, (5, 6, 7))
        helper(4, 8, 7, 7, 7, 3, stride=1)
        helper(10, 128, 19, 19, 19, 3, stride=2)
        helper(10, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(1, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(0, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(1, 79, 4, 4, 4, 3, stride=2)
        helper(0, 79, 4, 4, 4, 3, stride=2)

    @onlyCPU
    @dtypes(torch.half, torch.bfloat16)
    def test_max_pool_bfloat16_half(self, device, dtype):
        def helper(shape, kernel_size, stride, memory_format, dtype):
            input = torch.randn(shape, dtype=dtype, device=device)
            input = input.to(memory_format=memory_format).requires_grad_()
            if len(shape) == 4:
                pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(
                    device
                )
            else:
                pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(
                    device
                )

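            # Float32 reference copy; results are cast back to dtype for comparison.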
            input2 = input.detach().clone().float().requires_grad_(True)

            out, ind = pool(input)
            out.sum().backward()
            out2, ind2 = pool(input2)
            out2.sum().backward()

            self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2.to(dtype=dtype))
            self.assertEqual(ind, ind2)
            self.assertEqual(input.grad, input2.grad.to(dtype=dtype))

        helper((4, 30, 8, 8), 7, 1, torch.contiguous_format, dtype)
        helper((4, 65, 8, 8), 7, 1, torch.channels_last, dtype)
        helper((1, 19, 20, 10), 8, 2, torch.contiguous_format, dtype)
        helper((1, 19, 20, 10), 8, 2, torch.channels_last, dtype)
        helper((4, 30, 8, 8), 7, 1, torch.contiguous_format, dtype)
        helper((4, 65, 8, 8), 7, 1, torch.channels_last, dtype)
        helper((1, 19, 10, 10, 10), 8, 2, torch.contiguous_format, dtype)
        helper((1, 19, 10, 9, 14), 8, 2, torch.channels_last_3d, dtype)
        helper((4, 10, 3, 8, 8), 3, 1, torch.contiguous_format, dtype)
        helper((4, 10, 8, 8, 8), 7, 1, torch.channels_last_3d, dtype)

    @onlyCUDA
    @gcIfJetson
    def test_max_pool2d_indices(self, device):
        def helper(n, c, h, w, ks):
            if n is None:
                x = torch.randn(
                    c, h, w, device="cuda", dtype=torch.float, requires_grad=True
                )
            else:
                x = torch.randn(
                    n, c, h, w, device="cuda", dtype=torch.float, requires_grad=True
                )

            ref_x = x.detach().clone().cpu().requires_grad_()

            pool = torch.nn.MaxPool2d(kernel_size=ks, return_indices=True)

            y, idx = pool(x)
            ref_y, ref_idx = pool(ref_x)

            y.sum().backward()
            ref_y.sum().backward()

            self.assertEqual(y, ref_y)
            self.assertEqual(
                idx, ref_idx
            )  # assertEqual implicitly compares shape for tensors
            self.assertEqual(x.grad, ref_x.grad)

        helper(2, 8, 4, 4, ks=2)
        helper(None, 3, 50, 50, ks=5)

    @onlyCPU
    @dtypes(torch.half, torch.bfloat16)
    def test_avg_pool2d_reduced_floating(self, device, dtype):
        def helper(n, c, h, w, kernel_size, stride, memory_format):
            input = torch.randn(n, c, h, w, dtype=torch.float32, device=device).to(
                dtype=dtype
            )
            input = input.to(memory_format=memory_format).requires_grad_()
            pool = torch.nn.AvgPool2d(kernel_size, stride).to(device)

            input2 = input.detach().clone().float().requires_grad_(True)

            out = pool(input)
            out.sum().backward()
            out2 = pool(input2)
            out2.sum().backward()

            self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2.to(dtype=dtype))
            self.assertEqual(input.grad, input2.grad.to(dtype=dtype))

        helper(4, 30, 8, 8, 7, 1, torch.contiguous_format)
        helper(4, 65, 8, 8, 7, 1, torch.channels_last)
        helper(1, 19, 20, 10, 8, 2, torch.contiguous_format)
        helper(1, 19, 20, 10, 8, 2, torch.channels_last)

    @dtypes(torch.float, torch.double)
    def test_adaptive_pooling_max_nhwc(self, device, dtype):
        def helper(input_size, output_plane_size, contig):
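            # contig=False slices the channel dimension to produce
            # non-contiguous inputs and gradients.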
1333            n_plane_dims = len(output_plane_size)
1334            mod = (
1335                torch.nn.AdaptiveMaxPool2d
1336                if n_plane_dims == 2
1337                else torch.nn.AdaptiveMaxPool3d
1338            )
1339            channels_last = (
1340                torch.channels_last if n_plane_dims == 2 else torch.channels_last_3d
1341            )
1342            output_size = input_size[:2] + output_plane_size
1343            input = torch.randint(1, 10, input_size, device=device, dtype=dtype)
1344            input = input.contiguous(memory_format=channels_last)
1345            grad = torch.randint(1, 10, output_size, device=device, dtype=dtype)
1346            grad = grad.contiguous(memory_format=channels_last)
1347            if not contig:
1348                input = input[:, ::2]
1349                grad = grad[:, ::2]
1350            input.requires_grad_(True)
1351            pool = mod(output_plane_size, return_indices=True).to(device)
1352
1353            ref_input = input.detach().clone().contiguous().requires_grad_(True)
1354            ref_grad = grad.detach().clone().contiguous()
1355            ref_pool = mod(output_plane_size, return_indices=True).to(device)
1356
1357            out, ind = pool(input)
1358            out.backward(grad)
1359            ref_out, ref_ind = ref_pool(ref_input)
1360            ref_out.backward(ref_grad)
1361
1362            # channels_last_3d case does not return channels_last_3d outputs
1363            if n_plane_dims == 2:
1364                self.assertTrue(out.is_contiguous(memory_format=channels_last))
1365                self.assertTrue(ind.is_contiguous(memory_format=channels_last))
1366            self.assertTrue(ref_out.is_contiguous())
1367            self.assertTrue(ref_ind.is_contiguous())
1368            self.assertEqual(out, ref_out)
1369            self.assertEqual(ind, ref_ind)
1370            self.assertEqual(input.grad, ref_input.grad)
1371
1372        for contig in [True, False]:
1373            helper((4, 8, 10, 10), (7, 7), contig)
1374            helper((4, 8, 9, 14), (5, 8), contig)
1375            helper((4, 8, 11, 11), (1, 1), contig)
1376            helper((2, 1, 3, 3), (1, 1), contig)
1377            helper((4, 8, 10, 10, 10), (7, 7, 7), contig)
1378            helper((4, 8, 11, 11, 11), (1, 1, 1), contig)
1379            helper((2, 1, 3, 3, 3), (1, 1, 1), contig)
1380
1381    @dtypes(torch.float, torch.double)
1382    def test_pooling_max_nhwc(self, device, dtype):
1383        def helper(n, c, h, w, kernel_size, stride, padding, dilation, contig, device):
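            # Standard pooling output-size formula (as documented for
            # nn.MaxPool2d):
            #   out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1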
            output_height = math.floor(
                (h + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1)
                / stride[0]
                + 1
            )
            output_width = math.floor(
                (w + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1)
                / stride[1]
                + 1
            )

            input = torch.randint(1, 10, (n, c, h, w), device=device, dtype=dtype)
            input = input.contiguous(memory_format=torch.channels_last)
            grad = torch.randint(
                1, 10, (n, c, output_height, output_width), device=device, dtype=dtype
            )
            grad = grad.contiguous(memory_format=torch.channels_last)
            if not contig:
                input = input[:, ::2, :, :]
                grad = grad[:, ::2, :, :]
            input.requires_grad_(True)
            pool = torch.nn.MaxPool2d(
                kernel_size,
                stride,
                padding,
                dilation,
                return_indices=True,
                ceil_mode=False,
            )

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.MaxPool2d(
                kernel_size,
                stride,
                padding,
                dilation,
                return_indices=True,
                ceil_mode=False,
            ).to(device)

            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            self.assertEqual(input.grad, ref_input.grad)

        for contig in [True, False]:
            helper(4, 8, 10, 10, (2, 2), (1, 1), (1, 1), (2, 2), contig, device)
            helper(4, 8, 9, 14, (2, 2), (1, 1), (1, 1), (2, 2), contig, device)
            helper(4, 8, 11, 11, (4, 4), (2, 2), (2, 2), (2, 2), contig, device)

    @onlyCUDA
    def test_pool3d_size_one_feature_dim(self, device):
        # Tests arbitrary strides on a feature dim of size 1: a single-element
        # dim can carry any stride without changing element addressing, so the
        # pooling kernels must tolerate unusual values.
        x = torch.randn(7, 1, 5, 3, 2, device=device)
        strange_strides = [30, 1234, 6, 2, 1]
        y = x.as_strided(x.size(), strange_strides)
        x = x.cpu().as_strided(x.size(), strange_strides)

        to_test = {
            "max_pool3d": lambda t: F.max_pool3d(t, (5, 1, 1), stride=(5, 1, 1)),
            "avg_pool3d": lambda t: F.avg_pool3d(t, (5, 1, 1), stride=(5, 1, 1)),
        }

        for test, fn in to_test.items():
            # Should not crash
            out_y = fn(y)
            out_x = fn(x)
            self.assertEqual(out_y, out_x.to(device), msg=test)

    @onlyCUDA
    @largeTensorTest("18GB")
    @largeTensorTest("180GB", "cpu")
    def test_pool3d_large_size_int64(self, device):
        # See https://github.com/pytorch/pytorch/issues/52822
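        # 70 * 32 * 100**3 = 2.24e9 elements, which exceeds 2**31 - 1, so the
        # pooling kernels must use 64-bit indexing here.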
        x = torch.randn(
            70, 32, 100, 100, 100, dtype=torch.half, device=device, requires_grad=True
        )
        y = torch.nn.functional.max_pool3d(x, 5)
        g = torch.randn_like(y, dtype=torch.half)
        torch.cuda.synchronize()
        y.backward(g)
        torch.cuda.synchronize()

        ref_x = x.detach().cpu().float()  # max_pool3d_cpu is not implemented for half
        ref_x.requires_grad = True
        ref_g = g.cpu().float()
        ref_y = torch.nn.functional.max_pool3d(ref_x, 5)
        ref_y.backward(ref_g)

        self.assertEqual(y, ref_y, exact_dtype=False)
        self.assertEqual(x.grad, ref_x.grad, exact_dtype=False)

    @onlyCUDA
    def test_AvgPool3d_backward_after_cat_dim1_device(self, device):
        # x must have batch size 1 so that the contiguity checks in backward
        # are exercised
        x = torch.randn(1, 3, 4, 4, 4, device=device, requires_grad=True)
        y = F.avg_pool3d(x, kernel_size=3, padding=1, stride=2)

        grad = torch.randn(y.size(), device=device)
        # Increase the stride in dimension 0; the tensor is still reported as
        # contiguous because size[0] is 1.
        stride = list(grad.stride())
        stride[0] = stride[0] * 2
        grad.set_(grad.storage(), 0, grad.size(), stride)
        assert grad.is_contiguous()

        y.backward(grad)

    def _test_maxpool_indices(
        self, num_dim, adaptive=False, device="cpu", dtype=torch.float
    ):
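        # The input is arange(1, numel + 1), so values strictly increase along
        # the last dimensions; with kernel size 2 the max of every window is
        # its last element. That makes the expected outputs, flat indices, and
        # gradients below computable in closed form.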
        def expected_indices(dim, dtype):
            if dim == 1:
                return torch.tensor([1, 3], dtype=dtype).repeat(2, 2, 1)
            if dim == 2:
                return torch.tensor([[5, 7], [13, 15]], dtype=dtype).repeat(2, 2, 1, 1)

        def expected_grad(dim, dtype):
            if dim == 1:
                return torch.tensor([0, 1, 0, 1], dtype=dtype).repeat(2, 2, 1)
            grad = expected_grad(dim - 1, dtype=dtype)
            zero = torch.zeros(grad.size(), dtype=dtype)
            return torch.stack((zero, grad, zero, grad), 2)

        def expected_output(dim, dtype):
            if dim == 1:
                return torch.arange(2, 17, 2, dtype=dtype).view(2, 2, 2)
            if dim == 2:
                col = torch.arange(6, 63, 8, dtype=dtype)
                return torch.stack([col, col + 2], 1).view(2, 2, 2, 2)

        if adaptive:
            cls_name = "AdaptiveMaxPool{}d".format(num_dim)  # noqa: UP032
        else:
            # FIXME(#105716): Test fails when using f-string
            cls_name = "MaxPool{}d".format(num_dim)  # noqa: UP032
        module_cls = getattr(nn, cls_name)
        module = module_cls(2, return_indices=True).to(device, dtype=dtype)
        numel = 4 ** (num_dim + 1)
        input = (
            torch.arange(1, numel + 1)
            .view(2, 2, *repeat(4, num_dim))
            .to(device, dtype=dtype)
        )
        input_var = input.clone().detach().requires_grad_()

        # Check forward
        output, indices = module(input_var)
        if num_dim != 3:
            expected_indices = expected_indices(num_dim, dtype=indices.data.dtype)
            expected_output = expected_output(num_dim, dtype=output.data.dtype)
            self.assertEqual(indices.dim(), input.dim())
            self.assertEqual(indices.data.squeeze(), expected_indices)
            self.assertEqual(output.data.squeeze(), expected_output)
        self.assertTrue(output.requires_grad)
        self.assertFalse(indices.requires_grad)

        # Make sure backward works
        grad_output = torch.ones(output.size(), device=device, dtype=dtype)
        output.backward(grad_output, retain_graph=True)
        expected_grad = expected_grad(num_dim, dtype=input_var.grad.data.dtype)
        self.assertEqual(input_var.grad.data, expected_grad.view_as(input))

        # Make sure backward after changing indices will result in an error
        indices.add_(1)
        self.assertRaises(RuntimeError, lambda: output.backward(grad_output))

        # Make sure -Infinity is handled correctly
        t = torch.tensor([[[float("-inf")]]])
        m = nn.MaxPool1d(kernel_size=1, return_indices=True)
        output, indices = m(t)
        self.assertEqual(output[0, 0, 0], float("-inf"))
        self.assertEqual(indices[0, 0, 0], 0)

        t = torch.tensor([[[float("-inf")]]])
        m = nn.MaxPool2d(kernel_size=1, return_indices=True)
        output, indices = m(t)
        self.assertEqual(output[0, 0, 0], float("-inf"))
        self.assertEqual(indices[0, 0, 0], 0)

        t = torch.tensor([[[[float("-inf")]]]])
        m = nn.MaxPool3d(kernel_size=1, return_indices=True)
        output, indices = m(t)
        self.assertEqual(output[0, 0, 0, 0], float("-inf"))
        self.assertEqual(indices[0, 0, 0, 0], 0)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool1d_indices(self, device, dtype):
        self._test_maxpool_indices(1, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool2d_indices(self, device, dtype):
        self._test_maxpool_indices(2, device=device, dtype=dtype)

    @skipIfMps
    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool3d_indices(self, device, dtype):
        self._test_maxpool_indices(3, device=device, dtype=dtype)

    @skipIfMps
    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_AdaptiveMaxPool1d_indices(self, device, dtype):
        self._test_maxpool_indices(1, adaptive=True, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_AdaptiveMaxPool2d_indices(self, device, dtype):
        self._test_maxpool_indices(2, adaptive=True, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_AdaptiveMaxPool3d_indices(self, device, dtype):
        self._test_maxpool_indices(3, adaptive=True, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_maxpool_indices_no_batch_dim(self, device, dtype):
        """Check that indices without a batch dim are consistent with a batch of size one."""
        max_pool_cases = [
            (
                nn.MaxPool1d(3, return_indices=True),
                torch.randn(3, 5, device=device, dtype=dtype),
            ),
            (
                nn.MaxPool2d(3, return_indices=True),
                torch.randn(3, 5, 6, device=device, dtype=dtype),
            ),
            (
                nn.MaxPool3d(3, return_indices=True),
                torch.randn(3, 5, 6, 7, device=device, dtype=dtype),
            ),
            (
                nn.AdaptiveMaxPool1d(3, return_indices=True),
                torch.randn(3, 5, device=device, dtype=dtype),
            ),
            (
                nn.AdaptiveMaxPool2d(3, return_indices=True),
                torch.randn(3, 5, 6, device=device, dtype=dtype),
            ),
            (
                nn.AdaptiveMaxPool3d(3, return_indices=True),
                torch.randn(3, 5, 6, 7, device=device, dtype=dtype),
            ),
        ]

        for module, input in max_pool_cases:
            _, indices_no_batch = module(input)
            _, indices_single_batch = module(input.unsqueeze(0))
            self.assertEqual(indices_no_batch, indices_single_batch.squeeze(0))

    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float)
    @onlyNativeDeviceTypes  # TODO: Fails on XLA
    @gcIfJetson
    def test_max_pool_nan_inf(self, device, dtype):
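        # NaN inputs must propagate to the pooled output (the max over an
        # all-NaN window is NaN) and -inf must come through unchanged, with
        # and without autograd, for both plain and adaptive max pooling.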
        for adaptive in ["", "adaptive_"]:
            for num_dim in [1, 2, 3]:
                fn_name = f"{adaptive}max_pool{num_dim}d"
                fn = getattr(F, fn_name)

                x = torch.full(
                    [1, 1] + num_dim * [3],
                    nan,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                res = fn(x, 1 if adaptive else 3)
                res.backward(torch.randn_like(res))
                self.assertTrue(math.isnan(res.item()))
                x.requires_grad_(False)
                res = fn(x, 1 if adaptive else 3)
                self.assertTrue(math.isnan(res.item()))

                x2 = torch.full(
                    [1, 1] + num_dim * [3],
                    -inf,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                res2 = fn(x2, 1 if adaptive else 3)
                res2.backward(torch.randn_like(res2))
                self.assertTrue(math.isinf(res2.item()))
                x2.requires_grad_(False)
                res2 = fn(x2, 1 if adaptive else 3)
                self.assertTrue(math.isinf(res2.item()))

    @expectedFailureMeta  # RuntimeError: Unrecognized tensor type ID: Meta
    @onlyNativeDeviceTypes
    def test_fractional_max_pool2d(self, device):
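        # Check the output shape and first/second-order gradients for batched
        # and unbatched inputs, then verify the errors raised for invalid
        # kernel_size and output_size arguments.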
        with set_default_dtype(torch.double):
            x = torch.randn(1, 2, 7, 7, requires_grad=True, device=device)
            samples = x.new(1, 2, 2).uniform_()

            def func(x):
                return F.fractional_max_pool2d(
                    x, (2, 2), output_size=(3, 3), _random_samples=samples
                )

            self.assertEqual(func(x).shape, (1, 2, 3, 3))
            gradcheck(func, [x])
            gradgradcheck(func, [x])

            x = torch.randn(2, 7, 7, requires_grad=True, device=device)
            self.assertEqual(func(x).shape, (2, 3, 3))
            if self.device_type != "cuda":
                # Reference: https://github.com/pytorch/pytorch/issues/52427
                # Raises -> RuntimeError: TensorAccessor expected 4 dims but tensor has 3
                # on CUDA in gradcheck
                gradcheck(func, [x])
                gradgradcheck(func, [x])

            for kernel_size in [(), (1,)]:
                with self.assertRaisesRegex(RuntimeError, "kernel_size must either"):
                    # Incorrect kernel_size
                    F.fractional_max_pool2d(
                        x,
                        kernel_size=kernel_size,
                        output_size=(3, 3),
                        _random_samples=samples,
                    )

            err_large_msg = "too large relative to input "
            err_out_size_msg = "output_size must either"
            for output_size, msg in [
                ((9, 3), err_large_msg + "height"),
                ((3, 9), err_large_msg + "width"),
                ((3,), err_out_size_msg),
                ((), err_out_size_msg),
            ]:
                with self.assertRaisesRegex(RuntimeError, msg):
                    # Incorrect output_size
                    F.fractional_max_pool2d(
                        x, (2, 2), output_size=output_size, _random_samples=samples
                    )

    @expectedFailureMeta  # RuntimeError: Unrecognized tensor type ID: Meta
    @onlyNativeDeviceTypes
    def test_fractional_max_pool3d(self, device):
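        # Same coverage as test_fractional_max_pool2d above, for the 3d op
        # (which adds a "time" dimension to the too-large output_size errors).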
        with set_default_dtype(torch.double):
            x = torch.randn(1, 2, 7, 7, 7, requires_grad=True, device=device)
            samples = x.new(1, 2, 3).uniform_()

            def func(x):
                return F.fractional_max_pool3d(
                    x, (2, 2, 2), output_size=(3, 3, 3), _random_samples=samples
                )

            self.assertEqual(func(x).shape, (1, 2, 3, 3, 3))
            gradcheck(func, [x])
            gradgradcheck(func, [x])

            x = torch.randn(2, 7, 7, 7, requires_grad=True, device=device)
            self.assertEqual(func(x).shape, (2, 3, 3, 3))
            gradcheck(func, [x])
            gradgradcheck(func, [x])

            for kernel_size in [(), (1,), (1, 1)]:
                with self.assertRaisesRegex(RuntimeError, "kernel_size must either"):
                    # Incorrect kernel_size
                    F.fractional_max_pool3d(
                        x,
                        kernel_size=kernel_size,
                        output_size=(3, 3, 3),
                        _random_samples=samples,
                    )

            err_large_msg = "too large relative to input "
            err_out_size_msg = "output_size must either"
            for output_size, msg in [
                ((9, 3, 3), err_large_msg + "time"),
                ((3, 9, 3), err_large_msg + "height"),
                ((3, 3, 9), err_large_msg + "width"),
                ((3, 3), err_out_size_msg),
                ((3,), err_out_size_msg),
                ((), err_out_size_msg),
            ]:
                with self.assertRaisesRegex(RuntimeError, msg):
                    # Incorrect output_size
                    F.fractional_max_pool3d(
                        x, (2, 2, 2), output_size=output_size, _random_samples=samples
                    )

    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float)
    @onlyNativeDeviceTypes  # TODO: Fails on XLA
    def test_fractional_max_pool_nan_inf(self, device, dtype):
        for num_dim in [2, 3]:
            fn_name = f"FractionalMaxPool{num_dim}d"
            fn = getattr(nn, fn_name)(kernel_size=2, output_size=1)
            x = torch.full(
                [1, 1] + num_dim * [3],
                nan,
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            res = fn(x)
            res.backward(torch.randn_like(res))
            self.assertTrue(math.isnan(res.item()))

            x2 = torch.full(
                [1, 1] + num_dim * [3],
                -inf,
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            res2 = fn(x2)
            res2.backward(torch.randn_like(res2))
            self.assertTrue(math.isinf(res2.item()))

    @onlyNativeDeviceTypes  # TODO: RuntimeError message different on XLA
    def test_pooling_zero_stride(self, device):
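        # A stride of zero would mean dividing by zero when computing the
        # output size, so both the functional and the module form must raise.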
        for op in ("max", "avg"):
            for num_dim in [1, 2, 3]:
                fn_name = f"{op}_pool{num_dim}d"
                fn = getattr(F, fn_name)
                x = torch.ones([1, 2] + num_dim * [4], device=device, dtype=torch.float)
                self.assertRaisesRegex(
                    RuntimeError,
                    r"stride should not be zero|stride must be greater than zero",
                    lambda: fn(x, kernel_size=2, stride=0),
                )

                fn_module_name = f"{op.title()}Pool{num_dim}d"
                fn_module = getattr(nn, fn_module_name)(kernel_size=2, stride=0)
                self.assertRaisesRegex(
                    RuntimeError,
                    r"stride should not be zero|stride must be greater than zero",
                    lambda: fn_module(x),
                )

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_pool_large_size(self, device, dtype):
        for op in ("max", "avg"):
            for num_dim in [1, 2, 3]:
                fn_name = f"{op}_pool{num_dim}d"
                fn = getattr(F, fn_name)
                # 16777217 = 2**24 + 1 is the smallest positive integer that
                # float32 cannot represent exactly
                x = torch.ones(
                    [1, 1, 16777217] + (num_dim - 1) * [1], device=device, dtype=dtype
                )
                res = fn(x, 1, stride=1, padding=0)
                # check that the output length was computed exactly (a float32
                # size computation would round 16777217 down to 16777216)
                self.assertEqual(x.shape[2], res.shape[2])

    @onlyCUDA
    @largeTensorTest("6GB")
    def test_pooling_large(self, device):
        def helper(pool):
            inp = torch.randn(
                2**7 + 10, 2**8, 2**8, 2**8, dtype=torch.half, device=device
            )
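            # (2**7 + 10) * 2**24 ≈ 2.3e9 elements, which exceeds 2**31 - 1,
            # so the kernels must use 64-bit indexing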
            self.assertTrue(inp.numel() > 2**31 - 1)
            out = pool(inp)
            torch.cuda.synchronize()  # surface any asynchronous CUDA errors before the test ends

        helper(nn.MaxPool2d(4, 4))
        helper(nn.AvgPool2d(4, 4))
        helper(nn.FractionalMaxPool2d(4, 4))
        helper(nn.AdaptiveMaxPool2d((2**6, 2**6)))
        helper(nn.AdaptiveAvgPool2d((2**6, 2**6)))

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_pool_invalid_size(self, device, dtype):
        for op in ("max", "avg"):
            for num_dim in [1, 2, 3]:
                fn_name = f"{op}_pool{num_dim}d"
                if op == "max":
                    # New implementation without indices supports empty tensors
                    # TODO(Heitor) change once with_indices code is updated
                    fn_name += "_with_indices"
                fn = getattr(F, fn_name)
                # use a configuration that yields a zero-sized output only if
                # the floor division by the stride is performed correctly
                x = torch.ones([1, 1] + num_dim * [4], device=device, dtype=dtype)
                with self.assertRaisesRegex(RuntimeError, r"too small|smaller than"):
                    try:
                        res = fn(x, 3, stride=2, padding=0, dilation=2)
                    except TypeError:
                        # some implementations do not support dilation
                        res = fn(x, 6, stride=2, padding=0)

    @onlyCUDA
    def test_pooling_bfloat16(self, device):
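        # Compare each bfloat16 pooling op against a higher-precision
        # reference (see _test_bfloat16_ops) within a tolerance of 0.05.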
        _test_bfloat16_ops(
            self,
            torch.nn.AvgPool1d(3, stride=2),
            device,
            inp_dims=(8, 4, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AvgPool2d(3, stride=2),
            device,
            inp_dims=(8, 4, 16, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AvgPool3d(3, stride=2),
            device,
            inp_dims=(8, 4, 16, 16, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self, torch.nn.AdaptiveAvgPool1d(3), device, inp_dims=(8, 4, 16), prec=0.05
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AdaptiveAvgPool2d((3, 5)),
            device,
            inp_dims=(8, 4, 16, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AdaptiveAvgPool3d((3, 5, 7)),
            device,
            inp_dims=(8, 4, 16, 16, 16),
            prec=0.05,
        )

    def test_maxpool3d_non_square_backward(self, device):
        # A previous CUDA implementation of this backward computed the kernel
        # launch grid size with the last two dimensions interchanged, so the
        # tail of the longer dim was ignored. Test that every position
        # receives a gradient.
        for dim in (2, 3, 4):
            shape = tuple(32 if i != dim else 256 for i in range(4))
            x = torch.randn(shape, device=device, requires_grad=True)
            F.max_pool3d(x, kernel_size=(1, 1, 1)).sum().backward()
            self.assertEqual(x.grad, torch.ones_like(x.grad))

    @slowTest
    def test_adaptive_pool_odd_size(self, device):
        # See https://github.com/pytorch/pytorch/issues/81409
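        # Smoke test: there is no correctness assertion here; the calls below
        # only need to complete without crashing for these large odd sizes.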
        Ih, Iw, Oh, Ow = 5873, 3693, 3527, 2219
        imgs = torch.randint(low=0, high=256, size=(11, Ih, Iw), dtype=torch.float)
        imgs_ = F.adaptive_avg_pool2d(imgs, (Oh, Ow))
        imgs_ = F.adaptive_max_pool2d(imgs, (Oh, Ow))

        Id, Ih, Iw, Od, Oh, Ow = 3, 5873, 3693, 3, 3527, 2219
        imgs = torch.randint(low=0, high=256, size=(3, Id, Ih, Iw), dtype=torch.float)
        imgs_ = F.adaptive_avg_pool3d(imgs, (Od, Oh, Ow))
        imgs_ = F.adaptive_max_pool3d(imgs, (Od, Oh, Ow))


instantiate_device_type_tests(TestPoolingNNDeviceType, globals())
instantiate_parametrized_tests(TestPoolingNN)

if __name__ == "__main__":
    run_tests()