# Owner(s): ["module: nn"]
import itertools
import math
import operator
import os
import random
import subprocess
import sys
import unittest
from functools import partial, reduce
from itertools import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import inf, nan
from torch.autograd import gradcheck, gradgradcheck
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    expectedFailureMeta,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipCUDAIfRocm,
    TEST_WITH_ROCM,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_nn import (
    _test_bfloat16_ops,
    _test_module_empty_input,
    NNTestCase,
)
from torch.testing._internal.common_utils import (
    gcIfJetson,
    instantiate_parametrized_tests,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    skipIfMps,
    skipIfTorchDynamo,
    slowTest,
    subtest,
    TEST_WITH_UBSAN,
    TestCase,
)


class TestAvgPool(TestCase):
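    # Reference implementation: F.unfold lays each kernel_size window out as a
    # column, so with stride == kernel_size summing over dim=1 yields the sum
    # of every non-overlapping pooling window.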
    def _sum_pool2d(self, x, kernel_size):
        windows = torch.nn.functional.unfold(
            x, kernel_size=kernel_size, stride=kernel_size
        )
        return torch.sum(windows, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        # unfold does not support 3D sliding windows, so split the tensor
        # along the first (depth) dimension and sum each full-depth chunk
        h = kernel_size[0]
        split_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # _sum_pool2d expects a (1, 1, n, m) tensor, so unsqueeze twice
        split_x = [
            self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:])
            for t in split_x
        ]
        joined_x = torch.cat(split_x)
        return joined_x.view(1, joined_x.numel())

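    # Average pooling is the corresponding sum pool divided by the number of
    # elements per window.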
    def _avg_pool2d(self, x, kernel_size):
        size = reduce(operator.mul, kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        size = reduce(operator.mul, kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

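    # Exhaustively compare F.avg_pool2d against the unfold-based reference for
    # every kernel size that fits in the input.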
    def test_doubletensor_avg_pool2d(self):
        n, m = 5, 8
        input = torch.rand(1, 1, n, m, dtype=torch.double)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                actual = torch.nn.functional.avg_pool2d(input[0], (i, j))
                actual = actual.view(1, actual.numel())
                expected = self._avg_pool2d(input, (i, j))
                self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_doubletensor_avg_pool2d_with_divisor(self):
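        # divisor_override replaces the per-window element count as the
        # averaging denominator, so the expected result is the sum pool
        # divided by the override.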
        n, m = 3, 3
        input = torch.rand(1, 1, n, m, dtype=torch.double)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_doubletensor_avg_pool3d(self):
        h, w, d = 5, 6, 7
        input = torch.rand(h, w, d, dtype=torch.double)
        for i in range(1, h + 1):
            for j in range(1, w + 1):
                for k in range(1, d + 1):
                    actual = torch.nn.functional.avg_pool3d(
                        input.unsqueeze(0), (i, j, k)
                    )
                    actual = actual.view(1, actual.numel())
                    expected = self._avg_pool3d(input, (i, j, k))
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_doubletensor_avg_pool3d_with_divisor(self):
        h, w, d = 6, 5, 7
        input = torch.rand(h, w, d, dtype=torch.double)
        for i in range(1, h + 1):
            for j in range(1, w + 1):
                for k in range(1, d + 1):
                    for divisor in [1, 7, i * j]:
                        actual = torch.nn.functional.avg_pool3d(
                            input.unsqueeze(0), (i, j, k), divisor_override=divisor
                        )
                        actual = actual.view(1, actual.numel())
                        expected = self._sum_pool3d(input, (i, j, k)) / divisor
                        self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool1d_ceil_mode(self):
        # Regression test for gh-36977
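        # (with ceil_mode, the last pooling window can lie entirely inside
        # the implicit padding, which used to produce a 0/0 NaN)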
        x = 10 * torch.randn((1, 16, 4))
        y = torch.nn.functional.avg_pool1d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=1, stride=2
        )
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool1d(
                x.to("cuda"),
                ceil_mode=True,
                count_include_pad=True,
                kernel_size=1,
                stride=2,
            )
            self.assertTrue(not torch.isnan(y).any())

    def test_avg_pool2d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x,
            ceil_mode=True,
            count_include_pad=True,
            kernel_size=(1, 2),
            padding=(0, 1),
            stride=2,
        )
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool2d(
                x.to("cuda"),
                ceil_mode=True,
                count_include_pad=True,
                kernel_size=(1, 2),
                padding=(0, 1),
                stride=2,
            )
            self.assertTrue(not torch.isnan(y).any())

    def test_avg_pool3d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4, 4))
        y = torch.nn.functional.avg_pool3d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2, 3), stride=2
        )
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool3d(
                x.to("cuda"),
                ceil_mode=True,
                count_include_pad=True,
                kernel_size=(1, 2, 3),
                stride=2,
            )
            self.assertTrue(not torch.isnan(y).any())


class TestPoolingNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def test_adaptive_pooling_size_none(self):
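        # A None entry in output_size keeps the input's size for that
        # dimension, hence the trailing 4 in the expected output shape.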
        for numel in (2, 3):
            for pool_type in ("Max", "Avg"):
                cls_name = f"Adaptive{pool_type}Pool{numel}d"
                module_cls = getattr(nn, cls_name)
                output_size = (2,) * (numel - 1) + (None,)
                module = module_cls(output_size)

                input = torch.randn((4,) * (numel + 1))
                output = module(input)
                self.assertEqual(output.size(), (4,) + (2,) * (numel - 1) + (4,))

    @unittest.skipIf(TEST_WITH_UBSAN, "signed integer overflow error with UBSAN")
    def test_adaptive_pooling_size_overflow(self):
        # 0x3fffffffffffffff * 2 * 2 = 0xfffffffffffffffc = -4 as int64_t
        # Tensor::numel() returns int64_t, so the following checks that
        # negative allocation sizes are handled correctly
        self.assertRaises(
            RuntimeError,
            lambda: torch.nn.AdaptiveMaxPool1d(0x3FFFFFFFFFFFFFFF)(
                torch.empty([2, 2, 2])
            ),
        )

    def test_adaptive_pooling_avg_nhwc(self):
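        # Run identical pooling on a channels-last input and a contiguous
        # reference copy: outputs and input gradients must match, and the
        # channels-last output must keep its memory format.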
        device_list = ["cpu"]
        if TEST_CUDA:
            device_list.append("cuda")

        for device in device_list:
            input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device)
            input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
            grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device)
            pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

    def test_adaptive_pooling_avg_nhwc_non_contiguous(self):
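        # Slicing every other channel makes the input non-contiguous in any
        # memory format; the output should still come back channels-last and
        # match the contiguous reference.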
        device_list = ["cpu"]
        if TEST_CUDA:
            device_list.append("cuda")

        for device in device_list:
            input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device)
            input = input.contiguous(memory_format=torch.channels_last)
            input = input[:, ::2, :, :].requires_grad_()
            grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device)
            grad = grad[:, ::2, :, :]
            pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

    def test_adaptive_pooling_lower_precision(self):
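        # Compare bfloat16/float16 pooling against a float32 reference in
        # both contiguous and channels-last formats; the loose atol absorbs
        # reduced-precision rounding.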
        def _test_adaptive_pooling_lower_precision(
            self, device, dtype, mod, memory_format
        ):
            input = torch.randint(1, 10, (3, 19, 8, 8), dtype=torch.float32)
            input = input.to(device).to(memory_format=memory_format).requires_grad_()
            pool = mod((7, 7)).to(device)

            input2 = input.detach().clone().to(dtype=dtype).requires_grad_(True)

            out = pool(input)
            out.sum().backward()
            out2 = pool(input2)
            out2.sum().backward()

            self.assertTrue(out2.is_contiguous(memory_format=memory_format))
            self.assertEqual(out2.dtype, dtype)
            self.assertEqual(input2.grad.dtype, dtype)
            self.assertEqual(out, out2.float(), atol=0.1, rtol=0)
            self.assertEqual(input.grad, input2.grad.float(), atol=0.1, rtol=0)

        device_list = ["cpu"]
        for device in device_list:
            for dtype in [torch.bfloat16, torch.float16]:
                _test_adaptive_pooling_lower_precision(
                    self,
                    device,
                    dtype,
                    torch.nn.AdaptiveAvgPool2d,
                    torch.contiguous_format,
                )
                _test_adaptive_pooling_lower_precision(
                    self, device, dtype, torch.nn.AdaptiveAvgPool2d, torch.channels_last
                )
                _test_adaptive_pooling_lower_precision(
                    self,
                    device,
                    dtype,
                    torch.nn.AdaptiveMaxPool2d,
                    torch.contiguous_format,
                )
                _test_adaptive_pooling_lower_precision(
                    self, device, dtype, torch.nn.AdaptiveMaxPool2d, torch.channels_last
                )

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @largeTensorTest("12GB", device="cuda")
    def test_adaptive_pooling_avg_nhwc_launch_config_backward(self):
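        # A spatial dimension of 2**17 + 1 exceeds the 65535 limit on CUDA
        # grid y/z dimensions, exercising the kernel's launch-config handling
        # for oversized inputs; correctness is checked against a contiguous
        # reference.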
        input = torch.randint(
            1, 10, (1, 32, 2**17 + 1, 32), dtype=torch.float32, device="cuda"
        )
        input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
        grad = torch.randint(1, 10, (1, 32, 10, 32), dtype=torch.float32, device="cuda")

        pool = torch.nn.AdaptiveAvgPool2d((10, 32)).cuda()

        ref_input = input.detach().clone().contiguous().requires_grad_(True)
        ref_grad = grad.detach().clone().contiguous()
        ref_pool = torch.nn.AdaptiveAvgPool2d((10, 32)).cuda()

        out = pool(input)
        out.backward(grad)
        ref_out = ref_pool(ref_input)
        ref_out.backward(ref_grad)

        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(ref_out.is_contiguous())
        self.assertEqual(out, ref_out)
        self.assertEqual(input.grad, ref_input.grad)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @largeTensorTest("12GB", device="cuda")
    def test_adaptive_pooling_avg_nhwc_launch_config_forward(self):
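        # Same launch-config concern as above, but with the oversized
        # 2**17 + 1 dimension on the output side of the forward pass.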
        input = torch.randint(
            1, 10, (1, 32, 16, 16), dtype=torch.float32, device="cuda"
        )
        input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
        pool = torch.nn.AdaptiveAvgPool2d((2**17 + 1, 32)).cuda()

        ref_input = input.detach().clone().contiguous().requires_grad_(True)
        ref_pool = torch.nn.AdaptiveAvgPool2d((2**17 + 1, 32)).cuda()

        out = pool(input)
        ref_out = ref_pool(ref_input)

        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(ref_out.is_contiguous())
        self.assertEqual(out, ref_out)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_adaptive_avg_pooling_overflow(self):
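        # Each 128x128 averaging window sums ~16k values of magnitude up to
        # 256, far beyond float16's max of ~65504; the test asserts the
        # accumulation does not overflow to inf (or turn into nan).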
        input = torch.randint(
            -256, 256, (20, 32, 256, 256), dtype=torch.half, device="cuda"
        )
        avg_pool = torch.nn.AdaptiveAvgPool2d((2, 2))
        out = avg_pool(input)
        self.assertFalse(torch.isinf(out).any())
        self.assertFalse(torch.isnan(out).any())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_adaptive_avg_pooling_nhwc_overflow(self):
        input = torch.randint(
            -256, 256, (20, 32, 256, 256), dtype=torch.half, device="cuda"
        )
        input = input.contiguous(memory_format=torch.channels_last)
        avg_pool = torch.nn.AdaptiveAvgPool2d((2, 2))
        out = avg_pool(input)
        self.assertFalse(torch.isinf(out).any())
        self.assertFalse(torch.isnan(out).any())

    def test_MaxUnpool2d_output_size(self):
        m = nn.MaxPool2d(3, stride=2, return_indices=True)
        mu = nn.MaxUnpool2d(3, stride=2)
        big_t = torch.rand(1, 1, 6, 6)
        big_t[0][0][4][4] = 100
        output_big, indices_big = m(big_t)
        self.assertRaises(RuntimeError, lambda: mu(output_big, indices_big))

        small_t = torch.rand(1, 1, 5, 5)
        for i in range(0, 4, 2):
            for j in range(0, 4, 2):
                small_t[:, :, i, j] = 100
        output_small, indices_small = m(small_t)
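        # Unpooling a 2x2 map with kernel 3, stride 2 has default output size
        # (2 - 1) * 2 + 3 = 5; MaxUnpool2d accepts output sizes strictly
        # between default - stride and default + stride (here 4..6 per
        # dimension) and rejects everything else.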
        for h in range(3, 10):
            for w in range(3, 10):
                if 4 <= h <= 6 and 4 <= w <= 6:
                    size = (h, w)
                    if h == 6:
                        size = (1, 1) + size

                    mu(output_small, indices_small, output_size=size)
                else:
                    self.assertRaises(
                        ValueError, lambda: mu(output_small, indices_small, (h, w))
                    )

    def test_max_unpool2d_nhwc_cpu(self):
        input = torch.randn(2, 10, 9, 9).float().cpu()
        input = input.contiguous(memory_format=torch.channels_last)
        ref_input = input.clone().contiguous()

        pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()
        ref_pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()

        out, ind = pool(input)
        ref_out, ref_ind = ref_pool(ref_input)
        out.requires_grad_()
        ref_out.requires_grad_()

        unpool = nn.MaxUnpool2d(3, stride=2).cpu()
        ref_unpool = nn.MaxUnpool2d(3, stride=2).cpu()

        upout = unpool(out, ind)
        ref_upout = ref_unpool(ref_out, ref_ind)

        grad = torch.randn(upout.size()).float().cpu()
        grad = grad.contiguous(memory_format=torch.channels_last)
        ref_grad = grad.clone().contiguous()

        upout.backward(grad)
        ref_upout.backward(ref_grad)

        self.assertTrue(upout.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(ref_upout.is_contiguous())
        self.assertTrue(torch.allclose(upout, ref_upout))
        self.assertTrue(torch.allclose(out.grad, ref_out.grad))

    def test_max_unpool(self):
        with set_default_dtype(torch.double):
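            # gradcheck compares analytical gradients against finite
            # differences, which is only reliable in double precision.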
            # Test 1D
            output, indices = F.max_pool1d(
                torch.randn([1, 1, 4]), 2, stride=2, return_indices=True
            )
            self.assertEqual(
                F.max_unpool1d(output, indices, 2),
                F.max_unpool1d(output, indices, 2, stride=2),
            )

            # Test list / tuple passed as argument to max_unpool1d
            input = torch.randn([1, 1, 5], requires_grad=True)
            output, indices = F.max_pool1d(input, 2, stride=2, return_indices=True)
            self.assertEqual(
                F.max_unpool1d(output, indices, 2, stride=2, output_size=input.shape),
                F.max_unpool1d(output, indices, 2, stride=2, output_size=input.size()),
            )
            gradcheck(F.max_unpool1d, (output, indices, 2), check_forward_ad=True)

            # Test 2D
            output, indices = F.max_pool2d(
                torch.randn([1, 1, 4, 4], requires_grad=True),
                2,
                stride=2,
                return_indices=True,
            )
            self.assertEqual(
                F.max_unpool2d(output, indices, 2),
                F.max_unpool2d(output, indices, 2, stride=2),
            )
            gradcheck(F.max_unpool2d, (output, indices, 2), check_forward_ad=True)

            # Test 3D
            output, indices = F.max_pool3d(
                torch.randn([4, 4, 4, 4, 4], requires_grad=True),
                2,
                stride=2,
                return_indices=True,
            )
            self.assertEqual(
                F.max_unpool3d(output, indices, 2),
                F.max_unpool3d(output, indices, 2, stride=2),
            )
            gradcheck(F.max_unpool3d, (output, indices, 2), check_forward_ad=True)

    def test_max_unpool3d_input_check(self):
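        # Malformed arguments (a 2-element kernel_size for a 3D unpool) must
        # raise rather than crash.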
        x = torch.ones(1, 3, 1, 1, 1)
        with self.assertRaises(RuntimeError):
            F.max_unpool3d(x, torch.zeros(x.shape, dtype=int), [1, 1])

    def test_quantized_max_pool1d_empty_kernel(self):
        # This used to segfault when called with an empty kernel
        # see https://github.com/pytorch/pytorch/issues/116323
        base = torch.randn(1)
        temp_tensor = torch.quantize_per_tensor(base, 0.1, 10, torch.quint2x4)
        with self.assertRaises(RuntimeError):
            torch.quantized_max_pool1d(temp_tensor, [])


class TestPoolingNNDeviceType(NNTestCase):
    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    def test_adaptive_pooling_zero_batch(self, dtype, device):
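        # _test_module_empty_input runs forward and backward on the
        # zero-batch input and checks that the input gradient comes back as
        # zeros.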
        inp = torch.ones(0, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool1d(5).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        inp = torch.ones(0, 10, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool2d((5, 5)).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        inp = torch.ones(0, 10, 10, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool3d((5, 5, 5)).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    # This test verifies that adaptive_{avg, max}_pool and their variants
    # raise errors during backward propagation when output_size = 0.
    # It is written explicitly because ErrorInputs does not support backward
    # calls. Issue: https://github.com/pytorch/pytorch/issues/78868
    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16, torch.float16)
    def test_adaptive_pooling_empty_output_size(self, dtype, device):
        error_msg = (
            "Expected grad_output to have non-zero size for non-batch dimensions"
        )

        make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
        input = make_arg((1, 64, 10, 9))
        output_size = 0

        fns = (
            nn.functional.adaptive_avg_pool2d,
            nn.functional.adaptive_avg_pool3d,
            nn.functional.adaptive_max_pool2d,
            nn.functional.adaptive_max_pool3d,
        )

        for fn in fns:
            with self.assertRaisesRegex(RuntimeError, error_msg):
                fn(input, output_size).sum().backward()

        fns2 = (
            nn.functional.adaptive_avg_pool1d,
            nn.functional.adaptive_max_pool1d,
        )
        input2 = make_arg((1, 64))

        for fn in fns2:
            with self.assertRaisesRegex(RuntimeError, error_msg):
                fn(input2, output_size).sum().backward()

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool2d_zero_batch(self, device):
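        # A zero-sized batch dimension is fine, but a zero-sized channel
        # dimension must raise.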
        mod = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
        inp = torch.ones(0, 16, 50, 32, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected input"):
            inp = torch.randn(1, 0, 50, 32, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool3d_zero_batch(self, device):
        mod = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)).to(device)
        inp = torch.ones(0, 16, 50, 32, 32, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected input"):
            inp = torch.randn(1, 0, 50, 32, 32, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool2d_zero_out_size(self, device):
        mod = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1])
        inp = torch.rand([16, 50, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((16, 50, 0, 1), device=device))

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool3d_zero_out_size(self, device):
        mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[0, 1, 1])
        inp = torch.rand([16, 50, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device))

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool2d_zero_samples(self, device):
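        # _random_samples holds one row of samples per batch element, so a
        # zero-batch samples tensor is only compatible with zero-batch inputs.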
        samples = torch.rand([0, 16, 2], device=device)
        mod = nn.FractionalMaxPool2d(
            [2, 2], output_size=[1, 1], _random_samples=samples
        )
        inp = torch.randn([0, 16, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((0, 16, 1, 1), device=device))

        inp1 = torch.randn([1, 16, 32, 32], device=device)
        with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"):
            out1 = mod(inp1)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool3d_zero_samples(self, device):
        samples = torch.rand([0, 16, 3], device=device)
        mod = nn.FractionalMaxPool3d(
            [3, 2, 2], output_size=[1, 1, 1], _random_samples=samples
        )
        inp = torch.randn([0, 16, 50, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((0, 16, 1, 1, 1), device=device))

        inp1 = torch.randn([1, 16, 50, 32, 32], device=device)
        with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"):
            out1 = mod(inp1)

    @onlyNativeDeviceTypes
    def test_MaxPool_zero_batch_dim(self, device):
        inp = torch.randn(0, 16, 50, device=device)
        mod = torch.nn.MaxPool1d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        # 1D is supposed to be okay with 0 numel() inputs, so don't test
        # error raising for that case.

        inp = torch.randn(0, 16, 50, 32, device=device)
        mod = torch.nn.MaxPool2d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.randn(1, 0, 50, 32, device=device)
            mod(inp)

        inp = torch.ones(0, 16, 50, 44, 31, device=device)
        mod = torch.nn.MaxPool3d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.ones(1, 0, 50, 44, 31, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_MaxUnpool_zero_batch_dim(self, device):
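        # Pool-then-unpool a zero-batch input in 1D/2D/3D: both the unpooled
        # output and the input gradient should be empty zero tensors.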
        pool = torch.nn.MaxPool1d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool1d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        output.requires_grad_(True)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

        pool = torch.nn.MaxPool2d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool2d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

        pool = torch.nn.MaxPool3d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool3d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        output.requires_grad_(True)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

    @slowTest
    @onlyNativeDeviceTypes
    @skipCUDAIfRocm
    @parametrize_test(
        "module_name,module_size,output_size,test_index,should_error",
        [
            # Some tests are failing in trunk https://github.com/pytorch/pytorch/issues/103854
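            # Each case supplies either an out-of-range flat index
            # (should_error=True) or the largest valid one (should_error=False)
            # for the unpool indices tensor.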
            subtest(
                ("MaxUnpool2d", (2, 2), (1, 3, 4, 5), -1, True),
                name="case1",
            ),
            subtest(
                ("MaxUnpool2d", (2, 2), (1, 3, 4, 5), 2 * 2 * 4 * 5, True),
                name="case2",
            ),
            subtest(
                ("MaxUnpool2d", (2, 2), (1, 3, 4, 5), (2 * 2 * 4 * 5) - 1, False),
                name="case3",
            ),
            subtest(
                ("MaxUnpool2d", (2, 3), (2, 1, 4, 2), 2 * 3 * 4 * 2, True),
                name="case4",
            ),
            subtest(
                ("MaxUnpool2d", (2, 3), (2, 1, 4, 2), (2 * 3 * 4 * 2) - 1, False),
                name="case5",
            ),
            subtest(
                ("MaxUnpool3d", (2, 2, 2), (1, 3, 4, 5), -1, True),
                name="case6",
            ),
            subtest(
                ("MaxUnpool3d", (2, 2, 2), (1, 3, 4, 5), 2 * 2 * 2 * 3 * 4 * 5, True),
                name="case7",
            ),
            subtest(
                (
                    "MaxUnpool3d",
                    (2, 2, 2),
                    (1, 3, 4, 5),
                    (2 * 2 * 2 * 3 * 4 * 5) - 1,
                    False,
                ),
                name="case8",
            ),
            subtest(
                ("MaxUnpool3d", (2, 2, 2), (2, 3, 4, 1), 2 * 2 * 2 * 3 * 4 * 1, True),
                name="case9",
            ),
            subtest(
                (
                    "MaxUnpool3d",
                    (2, 2, 2),
                    (2, 3, 4, 1),
                    (2 * 2 * 2 * 3 * 4 * 1) - 1,
                    False,
                ),
                name="case10",
727*da0073e9SAndroid Build Coastguard Worker            ),
728*da0073e9SAndroid Build Coastguard Worker        ],
729*da0073e9SAndroid Build Coastguard Worker    )
730*da0073e9SAndroid Build Coastguard Worker    def test_MaxUnpool_index_errors(
731*da0073e9SAndroid Build Coastguard Worker        self, device, module_name, module_size, output_size, test_index, should_error
732*da0073e9SAndroid Build Coastguard Worker    ):
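        """MaxUnpool2d/MaxUnpool3d should reject out-of-range max indices,
        raising a device-side assert on CUDA and a RuntimeError on CPU."""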
        # NOTE: CUDA tests need to be run in a subprocess because they cause device asserts
        if torch.device(device).type == "cuda":
            error_msgs = {
                "MaxUnpool2d": r"Assertion `maxind >= 0 && maxind < outputImageSize` failed",
                "MaxUnpool3d": r"Assertion `index >= 0 && index < outputImageSize` failed",
            }

            script = f"""
import torch
unpool = torch.nn.{module_name}({module_size}).to('{device}')
output = torch.rand({output_size}, dtype=torch.float32, device='{device}')
indices = torch.zeros({output_size}, dtype=torch.int64, device='{device}')
indices.flatten()[0] = {test_index}
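# An out-of-range value here would trip the device-side assert in the unpooling kernel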
unpool(output, indices)
torch.cuda.synchronize()
"""
            p = subprocess.run(
                [sys.executable, "-c", script],
                cwd=os.path.dirname(os.path.realpath(__file__)),
                capture_output=True,
                text=True,
            )

            output = p.stdout + "\n" + p.stderr

            error_msg = error_msgs[module_name]

            if should_error:
                self.assertIn(error_msg, output, "The expected error was not found")
            else:
                self.assertNotIn("Error", output, "Should not have produced an error")
        else:
            module_class = getattr(torch.nn, module_name)
            unpool = module_class(module_size).to(device)
            output = torch.rand(output_size, dtype=torch.float32, device=device)
            indices = torch.zeros(output_size, dtype=torch.int64, device=device)
            indices.flatten()[0] = test_index

            if should_error:
                with self.assertRaisesRegex(
                    RuntimeError, r"Found an invalid max index:"
                ):
                    unpool(output, indices)
            else:
                unpool(output, indices)

    @onlyNativeDeviceTypes
    def test_AdaptiveMaxPool_zero_batch_dim(self, device):
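        """Adaptive max pooling should accept a zero-sized batch dimension
        but reject inputs that are empty in a non-batch dimension."""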
        inp = torch.randn(0, 16, 50, device=device)
        mod = torch.nn.AdaptiveMaxPool1d(3).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.randn(1, 0, 50, device=device)
            mod(inp)

        inp = torch.randn(0, 16, 50, 32, device=device)
        mod = torch.nn.AdaptiveMaxPool2d(3).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.randn(1, 0, 50, 32, device=device)
            mod(inp)

        inp = torch.ones(0, 16, 50, 44, 31, device=device)
        mod = torch.nn.AdaptiveMaxPool3d(3).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.ones(1, 0, 50, 44, 31, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_AvgPool2d_empty(self, device):
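        """AvgPool2d should handle zero-batch inputs in contiguous and
        channels_last layouts, and reject inputs that are empty in a
        non-batch dimension."""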
        avgpool = torch.nn.AvgPool2d(3, stride=2).to(device)
        inp = torch.randn(0, 16, 20, 32, device=device)
        _test_module_empty_input(self, avgpool, inp, check_size=False)

        clast_inp = torch.randn(0, 16, 20, 32, device=device).contiguous(
            memory_format=torch.channels_last
        )
        _test_module_empty_input(self, avgpool, clast_inp, check_size=False)

        # an input that is empty in a non-batch dimension should raise
        with self.assertRaisesRegex(RuntimeError, "3D or 4D"):
            inp = torch.randn(16, 0, 20, 32, device=device)
            avgpool(inp)

    def test_pooling_shape(self, device):
        """Test the output shape calculation for pooling functions"""

        # Checks output shape against expected for 1D, 2D and 3D
        def check(expected_out_shape, sizes, *args, **kwargs):
            for kernel in ["max", "avg"]:
                for i in [1, 2, 3]:
                    if hasattr(torch.nn.functional, f"{kernel}_pool{i}d"):
                        op = getattr(torch.nn.functional, f"{kernel}_pool{i}d")
                        t = torch.randn(sizes[: i + 2], device=device)
                        self.assertEqual(
                            op(t, *args, **kwargs).shape, expected_out_shape[: i + 2]
                        )

        check(
            (1, 1, 3, 3, 4),
            (1, 1, 5, 6, 7),
            kernel_size=1,
            stride=2,
            padding=0,
            ceil_mode=True,
        )
        check(
            (1, 1, 2, 3, 3),
            (1, 1, 3, 4, 5),
            kernel_size=2,
            stride=2,
            padding=1,
            ceil_mode=False,
        )
        check(
            (1, 1, 2, 3, 3),
            (1, 1, 3, 4, 5),
            kernel_size=2,
            stride=2,
            padding=1,
            ceil_mode=True,
        )

        # Test case from issue https://github.com/pytorch/pytorch/issues/45357
        x = torch.randn(1, 1, 6, 7, device=device)
        y = torch.nn.functional.max_pool2d(
            x, 1, stride=(2, 2), padding=0, ceil_mode=True
        )
        self.assertEqual(y.size(), (1, 1, 3, 4))

    @onlyNativeDeviceTypes  # TODO: fix on XLA
    def test_adaptive_avg_pool2d_output_size_one(self, device):
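        """AdaptiveAvgPool2d((1, 1)) should match a plain spatial mean and
        produce the expected strides for each memory format."""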
        def helper(size, memory_format):
            x = torch.randint(
                1, 10, size, dtype=torch.float, device=device, requires_grad=True
            )
            if memory_format == "non_contiguous":
                x = x[::2, ::2, ::2, ::2]
            else:
                x = x.to(memory_format=memory_format)

            net = torch.nn.AdaptiveAvgPool2d((1, 1))
            out = net(x)
            ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))

            out.sum().backward()  # make sure it doesn't crash

            self.assertEqual(out, ref_out)
            if memory_format == torch.channels_last:
                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, c, c])
            else:
                self.assertTrue(out.is_contiguous())
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, 1, 1])

        for mf in (torch.contiguous_format, torch.channels_last, "non_contiguous"):
            helper((2, 3, 6, 6), mf)

    @onlyNativeDeviceTypes
    def test_adaptive_avg_pool3d_output_size_one(self, device):
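        """AdaptiveAvgPool3d(1) should match a plain spatial mean and return
        a contiguous output."""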
        x = torch.randn(
            (2, 3, 6, 6, 6), dtype=torch.float, device=device, requires_grad=True
        )

        net = torch.nn.AdaptiveAvgPool3d(1)
        out = net(x)
        ref_out = x.contiguous().mean((-1, -2, -3)).view(out.shape)

        out.sum().backward()  # make sure it doesn't crash

        self.assertEqual(out, ref_out)
        self.assertTrue(out.is_contiguous())
        c = out.size(1)
        self.assertEqual(out.stride(), [c, 1, 1, 1, 1])

    @expectedFailureMeta  # RuntimeError not raised for meta
    @onlyNativeDeviceTypes
    @dtypes(torch.uint8, torch.int8, torch.short, torch.int, torch.long)
    def test_adaptive_pooling_no_support_input(self, device, dtype):
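        """Adaptive pooling should raise a RuntimeError for integer dtypes,
        which are not implemented."""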
        for numel in (2, 3):
            for pool_type in ("Max", "Avg"):
                cls_name = f"Adaptive{pool_type}Pool{numel}d"
                module_cls = getattr(nn, cls_name)
                output_size = (2,) * numel
                module = module_cls(output_size)
                input = torch.randn((4,) * (numel + 1), device=device).to(dtype)
                with self.assertRaisesRegex(RuntimeError, "not implemented"):
                    output = module(input)

    @onlyNativeDeviceTypes
    @gcIfJetson
    @dtypes(torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    def test_avg_pool2d_nhwc(self, device, dtype):
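        """AvgPool2d on channels_last input should match a contiguous
        reference in both forward and backward."""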
        def helper(
            n,
            c,
            h,
            w,
            kernel_size,
            stride=None,
            count_include_pad=True,
            divisor_override=None,
            padding=0,
        ):
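            # NOTE: `padding` is currently not forwarded to AvgPool2d below,
            # so the grad shape is computed without it.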
            if stride is None:
                stride = kernel_size
            input = torch.randn(n, c, h, w, dtype=dtype, device=device)
            input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
            grad = torch.randn(
                n,
                c,
                (h - kernel_size) // stride + 1,
                (w - kernel_size) // stride + 1,
                dtype=dtype,
                device=device,
            )
            pool = torch.nn.AvgPool2d(
                kernel_size,
                stride=stride,
                count_include_pad=count_include_pad,
                divisor_override=divisor_override,
            ).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AvgPool2d(
                kernel_size,
                stride=stride,
                count_include_pad=count_include_pad,
                divisor_override=divisor_override,
            ).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

        helper(4, 8, 8, 8, 3)
        helper(4, 8, 8, 8, 3, count_include_pad=False, padding=1)
        helper(4, 8, 8, 8, 3, count_include_pad=False, padding=2, stride=2)
        helper(4, 8, 8, 8, 3, divisor_override=42)
        helper(4, 8, 8, 8, 7)
        # ROCm 16GB MI25 hits OOM error. Clear caching allocator prior to running large subtest.
        if TEST_WITH_ROCM and "cuda" in device:
            torch.cuda.empty_cache()
        helper(200, 512, 28, 28, 2)
        helper(4, 8, 7, 7, 3, stride=1)
        helper(4, 8, 7, 7, 3, padding=2, stride=1)
        helper(10, 512, 31, 31, 3, stride=2)
        helper(1, 129, 8, 8, 3, stride=2)

    @onlyCPU
    @dtypes(torch.float, torch.double)
    def test_max_pool1d_corner_cases(self, device, dtype):
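        """Spot-check max_pool1d results on tiny inputs with padding,
        dilation, and ceil_mode."""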
        def check(x, args, expected):
            model = torch.nn.MaxPool1d(*args)
            if isinstance(x, list):
                x = torch.tensor(x, device=device, dtype=dtype)
                expected = torch.tensor(expected, device=device, dtype=dtype)
            self.assertEqual(model(x), expected)

        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        check([[1]], (1, None, 0, 1, False, False), [[1]])
        check([[1]], (2, None, 1, 2, False, False), [[float("-inf")]])
        check(
            [[1], [1]],
            (2, None, 1, 2, False, False),
            [[float("-inf")], [float("-inf")]],
        )
        check([[1, 2]], (2, 1, 1, 2, False, False), [[2, 1]])
        check([[1, 2]], (2, 2, 1, 2, False, True), [[2, 2]])

    @onlyCPU
    @dtypes(torch.float, torch.double)
    @skipIfTorchDynamo("OOMs https://github.com/pytorch/pytorch/issues/111320")
    def test_max_pool1d(self, device, dtype):
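        """Compare max_pool1d without indices against the return_indices=True
        reference over random shapes and hyperparameters, including
        non-contiguous inputs."""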
        # FIXME: for now, compare against max_pool1d with indices
        def check(x, *args, **kwargs):
            model = torch.nn.MaxPool1d(*args, **kwargs)
            ref_model = torch.nn.MaxPool1d(*args, **kwargs, return_indices=True)
            self.assertEqual(model(x), ref_model(x)[0])

        sizes = [random.sample(range(8, 128), 3) for _ in range(3)]
        kernel_sizes = random.sample(range(1, 5), 3)
        strides = random.sample(range(1, 5), 3)
        dilations = random.sample(range(1, 5), 3)
        ceil_modes = [True, False]

        for size, kernel_size, stride, dilation, ceil_mode in itertools.product(
            sizes, kernel_sizes, strides, dilations, ceil_modes
        ):
            padding = random.sample(range(0, math.floor(kernel_size / 2) + 1), 1)
            check(
                torch.randn(size, device=device, dtype=dtype),
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode=ceil_mode,
            )

        # Non-contiguous test
        tensor = torch.randn(5, 151, 33, device=device, dtype=dtype)[::2, ::3, ::2]
        check(tensor, 3, 2, 1, 2, ceil_mode=True)
        check(tensor.transpose(1, 2), 3, 2, 1, 2, ceil_mode=True)

    @onlyCUDA
    @gcIfJetson
    def test_max_pool2d(self, device):
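        """Compare CUDA MaxPool2d forward and backward against the CPU
        reference, including channel-heavy inputs."""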
        def helper(n, c, h, w, ks):
            x = torch.randn(
                n, c, h, w, device="cuda", dtype=torch.float, requires_grad=True
            )
            ref_x = x.detach().clone().cpu().requires_grad_()

            pool = torch.nn.MaxPool2d(kernel_size=ks)

            y = pool(x)
            ref_y = pool(ref_x)

            y.sum().backward()
            ref_y.sum().backward()

            self.assertEqual(y, ref_y)
            self.assertEqual(x.grad, ref_x.grad)

        helper(2, 8, 4, 4, ks=2)
        helper(1, 100000, 32, 32, ks=4)
        helper(1, 100000, 1, 4, ks=(1, 4))  # test for max_pool1d

    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @gcIfJetson
    def test_max_pool2d_nhwc(self, device, dtype):
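        """MaxPool2d on channels_last input should match a contiguous
        reference, including the returned indices and input gradients."""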
        def helper(n, c, h, w, kernel_size, stride=None):
            if stride is None:
                stride = kernel_size
            input = torch.randn(n, c, h, w, dtype=dtype, device=device)
            input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
            grad = torch.randn(
                n,
                c,
                (h - kernel_size) // stride + 1,
                (w - kernel_size) // stride + 1,
                dtype=dtype,
                device=device,
            )
            pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(
                device
            )

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(
                device
            )

            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            self.assertEqual(input.grad, ref_input.grad)

        helper(4, 8, 8, 8, 7)
        helper(200, 512, 28, 28, 2)
        helper(4, 8, 7, 7, 3, stride=1)
        helper(10, 512, 31, 31, 3, stride=2)
        helper(1, 129, 8, 8, 3, stride=2)

    @onlyCPU
    @dtypes(torch.int32, torch.int64)
    def test_max_pool2d_corner_cases(self, device, dtype):
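        """Spot-check MaxPool2d on small integer inputs in contiguous and
        channels_last layouts."""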
        def check(x, args, expected, memory_format):
            model = torch.nn.MaxPool2d(*args)
            if isinstance(x, list):
                x = torch.tensor(x, device=device, dtype=dtype).to(
                    memory_format=memory_format
                )
                expected = torch.tensor(expected, device=device, dtype=dtype).to(
                    memory_format=memory_format
                )
            self.assertEqual(model(x), expected)

        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        check(
            [[[[-1, -2], [-3, -4]]]],
            (2, 2, 1, 2, False, True),
            [[[[-4, -4], [-4, -4]]]],
            torch.contiguous_format,
        )
        check(
            [[[[-1, -2], [-3, -4]]]],
            (2, 2, 1, 2, False, True),
            [[[[-4, -4], [-4, -4]]]],
            torch.channels_last,
        )

    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @gcIfJetson
    def test_max_pool3d_ndhwc(self, device, dtype):
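        """MaxPool3d on channels_last_3d input, batched and unbatched, should
        match a contiguous reference, including indices and gradients."""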
        def helper(n, c, h, w, d, kernel_size, stride=None):
            batch = n
            if not batch:
                batch = 1
            input = torch.randn(batch, c, d, h, w, dtype=dtype, device=device)
            input = input.contiguous(
                memory_format=torch.channels_last_3d
            ).requires_grad_()
            if not n:
                input = input.squeeze(0).detach().clone().requires_grad_()
            if isinstance(kernel_size, int):
                kernel_size = [kernel_size] * 3
            if stride is None:
                stride = kernel_size
            elif isinstance(stride, int):
                stride = [stride] * 3
            grad = torch.randn(
                batch,
                c,
                (d - kernel_size[0]) // stride[0] + 1,
                (h - kernel_size[1]) // stride[1] + 1,
                (w - kernel_size[2]) // stride[2] + 1,
                dtype=dtype,
                device=device,
            )
            grad = grad.contiguous(memory_format=torch.channels_last_3d)
            if not n:
                grad = grad.squeeze(0)
            pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(
                device
            )

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(
                device
            )
            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            if len(out.shape) == 4:
                self.assertTrue(
                    out.unsqueeze(0).is_contiguous(memory_format=torch.channels_last_3d)
                )
            else:
                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(ref_out.is_contiguous())
            if len(ind.shape) == 4:
                self.assertTrue(
                    ind.unsqueeze(0).is_contiguous(memory_format=torch.channels_last_3d)
                )
            else:
                self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            if dtype == torch.half:
                self.assertEqual(input.grad, ref_input.grad, atol=0.05, rtol=0.01)
            else:
                self.assertEqual(input.grad, ref_input.grad)

        helper(4, 8, 8, 8, 8, 7)
        helper(4, 8, 8, 8, 8, (5, 6, 7))
        helper(1, 8, 8, 8, 8, (5, 6, 7))
        helper(0, 6, 12, 13, 14, (5, 6, 7))
        helper(4, 8, 7, 7, 7, 3, stride=1)
        helper(10, 128, 19, 19, 19, 3, stride=2)
        helper(10, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(1, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(0, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(1, 79, 4, 4, 4, 3, stride=2)
        helper(0, 79, 4, 4, 4, 3, stride=2)

    @onlyCPU
    @dtypes(torch.half, torch.bfloat16)
    def test_max_pool_bfloat16_half(self, device, dtype):
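        """MaxPool2d/MaxPool3d in half and bfloat16 should match a float32
        reference, with outputs and gradients in the reduced dtype."""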
        def helper(shape, kernel_size, stride, memory_format, dtype):
            input = torch.randn(shape, dtype=dtype, device=device)
            input = input.to(memory_format=memory_format).requires_grad_()
            if len(shape) == 4:
                pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(
                    device
                )
            else:
                pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(
                    device
                )

            input2 = input.detach().clone().float().requires_grad_(True)

            out, ind = pool(input)
            out.sum().backward()
            out2, ind2 = pool(input2)
            out2.sum().backward()

            self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2.to(dtype=dtype))
            self.assertEqual(ind, ind2)
            self.assertEqual(input.grad, input2.grad.to(dtype=dtype))

        helper((4, 30, 8, 8), 7, 1, torch.contiguous_format, dtype)
        helper((4, 65, 8, 8), 7, 1, torch.channels_last, dtype)
        helper((1, 19, 20, 10), 8, 2, torch.contiguous_format, dtype)
        helper((1, 19, 20, 10), 8, 2, torch.channels_last, dtype)
        helper((4, 30, 8, 8), 7, 1, torch.contiguous_format, dtype)
        helper((4, 65, 8, 8), 7, 1, torch.channels_last, dtype)
        helper((1, 19, 10, 10, 10), 8, 2, torch.contiguous_format, dtype)
        helper((1, 19, 10, 9, 14), 8, 2, torch.channels_last_3d, dtype)
        helper((4, 10, 3, 8, 8), 3, 1, torch.contiguous_format, dtype)
        helper((4, 10, 8, 8, 8), 7, 1, torch.channels_last_3d, dtype)

    @onlyCUDA
    @gcIfJetson
    def test_max_pool2d_indices(self, device):
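        """MaxPool2d with return_indices=True on CUDA should match the CPU
        reference for batched and unbatched inputs."""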
        def helper(n, c, h, w, ks):
            if n is None:
                x = torch.randn(
                    c, h, w, device="cuda", dtype=torch.float, requires_grad=True
                )
            else:
                x = torch.randn(
                    n, c, h, w, device="cuda", dtype=torch.float, requires_grad=True
                )

            ref_x = x.detach().clone().cpu().requires_grad_()

            pool = torch.nn.MaxPool2d(kernel_size=ks, return_indices=True)

            y, idx = pool(x)
            ref_y, ref_idx = pool(ref_x)

            y.sum().backward()
            ref_y.sum().backward()

            self.assertEqual(y, ref_y)
            self.assertEqual(
                idx, ref_idx
            )  # assertEqual implicitly compares shape for tensors
            self.assertEqual(x.grad, ref_x.grad)

        helper(2, 8, 4, 4, ks=2)
        helper(None, 3, 50, 50, ks=5)

    @onlyCPU
    @dtypes(torch.half, torch.bfloat16)
    def test_avg_pool2d_reduced_floating(self, device, dtype):
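        """AvgPool2d in half and bfloat16 should match a float32 reference,
        with outputs and gradients in the reduced dtype."""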
        def helper(n, c, h, w, kernel_size, stride, memory_format):
            input = torch.randn(n, c, h, w, dtype=torch.float32, device=device).to(
                dtype=dtype
            )
            input = input.to(memory_format=memory_format).requires_grad_()
            pool = torch.nn.AvgPool2d(kernel_size, stride).to(device)

            input2 = input.detach().clone().float().requires_grad_(True)

            out = pool(input)
            out.sum().backward()
            out2 = pool(input2)
            out2.sum().backward()

            self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2.to(dtype=dtype))
            self.assertEqual(input.grad, input2.grad.to(dtype=dtype))

        helper(4, 30, 8, 8, 7, 1, torch.contiguous_format)
        helper(4, 65, 8, 8, 7, 1, torch.channels_last)
        helper(1, 19, 20, 10, 8, 2, torch.contiguous_format)
        helper(1, 19, 20, 10, 8, 2, torch.channels_last)

    @dtypes(torch.float, torch.double)
    def test_adaptive_pooling_max_nhwc(self, device, dtype):
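        """Adaptive max pooling on channels_last(_3d) input should match a
        contiguous reference, including indices and gradients."""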
        def helper(input_size, output_plane_size, contig):
            n_plane_dims = len(output_plane_size)
            mod = (
                torch.nn.AdaptiveMaxPool2d
                if n_plane_dims == 2
                else torch.nn.AdaptiveMaxPool3d
            )
            channels_last = (
                torch.channels_last if n_plane_dims == 2 else torch.channels_last_3d
            )
            output_size = input_size[:2] + output_plane_size
            input = torch.randint(1, 10, input_size, device=device, dtype=dtype)
            input = input.contiguous(memory_format=channels_last)
            grad = torch.randint(1, 10, output_size, device=device, dtype=dtype)
            grad = grad.contiguous(memory_format=channels_last)
            if not contig:
                input = input[:, ::2]
                grad = grad[:, ::2]
            input.requires_grad_(True)
            pool = mod(output_plane_size, return_indices=True).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = mod(output_plane_size, return_indices=True).to(device)

            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            # channels_last_3d case does not return channels_last_3d outputs
            if n_plane_dims == 2:
                self.assertTrue(out.is_contiguous(memory_format=channels_last))
                self.assertTrue(ind.is_contiguous(memory_format=channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            self.assertEqual(input.grad, ref_input.grad)

        for contig in [True, False]:
            helper((4, 8, 10, 10), (7, 7), contig)
            helper((4, 8, 9, 14), (5, 8), contig)
            helper((4, 8, 11, 11), (1, 1), contig)
            helper((2, 1, 3, 3), (1, 1), contig)
            helper((4, 8, 10, 10, 10), (7, 7, 7), contig)
            helper((4, 8, 11, 11, 11), (1, 1, 1), contig)
            helper((2, 1, 3, 3, 3), (1, 1, 1), contig)

    @dtypes(torch.float, torch.double)
    def test_pooling_max_nhwc(self, device, dtype):
1383*da0073e9SAndroid Build Coastguard Worker        def helper(n, c, h, w, kernel_size, stride, padding, dilation, contig, device):
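            # Expected spatial sizes follow the standard max-pool shape formula
            # for ceil_mode=False:
            #   out = floor((in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1)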
            output_height = math.floor(
                (h + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1)
                / stride[0]
                + 1
            )
            output_width = math.floor(
                (w + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1)
                / stride[1]
                + 1
            )

            input = torch.randint(1, 10, (n, c, h, w), device=device, dtype=dtype)
            input = input.contiguous(memory_format=torch.channels_last)
            grad = torch.randint(
                1, 10, (n, c, output_height, output_width), device=device, dtype=dtype
            )
            grad = grad.contiguous(memory_format=torch.channels_last)
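            # Slicing every other channel below yields a non-contiguous view,
            # exercising the stride-aware path of the pooling kernels.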
            if not contig:
                input = input[:, ::2, :, :]
                grad = grad[:, ::2, :, :]
            input.requires_grad_(True)
            pool = torch.nn.MaxPool2d(
                kernel_size,
                stride,
                padding,
                dilation,
                return_indices=True,
                ceil_mode=False,
            )

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.MaxPool2d(
                kernel_size,
                stride,
                padding,
                dilation,
                return_indices=True,
                ceil_mode=False,
            ).to(device)

            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            self.assertEqual(input.grad, ref_input.grad)

        for contig in [True, False]:
            helper(4, 8, 10, 10, (2, 2), (1, 1), (1, 1), (2, 2), contig, device)
            helper(4, 8, 9, 14, (2, 2), (1, 1), (1, 1), (2, 2), contig, device)
            helper(4, 8, 11, 11, (4, 4), (2, 2), (2, 2), (2, 2), contig, device)

    @onlyCUDA
    def test_pool3d_size_one_feature_dim(self, device):
        # Tests unusual strides for a feature dim of size 1; that dim is only
        # ever indexed at 0, so its stride can be arbitrary without changing
        # which elements are addressed.
        x = torch.randn(7, 1, 5, 3, 2, device=device)
        strange_strides = [30, 1234, 6, 2, 1]
        y = x.as_strided(x.size(), strange_strides)
        x = x.cpu().as_strided(x.size(), strange_strides)

        to_test = {
            "max_pool3d": lambda t: F.max_pool3d(t, (5, 1, 1), stride=(5, 1, 1)),
            "avg_pool3d": lambda t: F.avg_pool3d(t, (5, 1, 1), stride=(5, 1, 1)),
        }

        for test, fn in to_test.items():
            # Should not crash
            out_y = fn(y)
            out_x = fn(x)
            self.assertEqual(out_y, out_x.to(device), msg=test)

    @onlyCUDA
    @largeTensorTest("18GB")
    @largeTensorTest("180GB", "cpu")
    def test_pool3d_large_size_int64(self, device):
        # See https://github.com/pytorch/pytorch/issues/52822
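        # 70 * 32 * 100**3 = 2.24e9 elements exceeds the int32 limit of
        # 2**31 - 1, so the CUDA kernels must use 64-bit index arithmetic.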
        x = torch.randn(
            70, 32, 100, 100, 100, dtype=torch.half, device=device, requires_grad=True
        )
        y = torch.nn.functional.max_pool3d(x, 5)
        g = torch.randn_like(y, dtype=torch.half)
        torch.cuda.synchronize()
        y.backward(g)
        torch.cuda.synchronize()

        ref_x = x.detach().cpu().float()  # max_pool3d_cpu is not implemented for half
        ref_x.requires_grad = True
        ref_g = g.cpu().float()
        ref_y = torch.nn.functional.max_pool3d(ref_x, 5)
        ref_y.backward(ref_g)

        self.assertEqual(y, ref_y, exact_dtype=False)
        self.assertEqual(x.grad, ref_x.grad, exact_dtype=False)

    @onlyCUDA
    def test_AvgPool3d_backward_after_cat_dim1_device(self, device):
        # x must have batch_size 1 to exercise the contiguity checks
        x = torch.randn(1, 3, 4, 4, 4, device=device, requires_grad=True)
        y = F.avg_pool3d(x, kernel_size=3, padding=1, stride=2)

        grad = torch.randn(y.size(), device=device)
        # Increase the stride in dimension 0; the tensor is still contiguous
        # because size[0] is 1.
        stride = list(grad.stride())
        stride[0] = stride[0] * 2
        grad.set_(grad.storage(), 0, grad.size(), stride)
        assert grad.is_contiguous()

        y.backward(grad)

    def _test_maxpool_indices(
        self, num_dim, adaptive=False, device="cpu", dtype=torch.float
    ):
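        # With return_indices=True, max pooling reports argmax positions
        # flattened over each input plane. The arange-based input below makes
        # every window's maximum its last element: in 1d, kernel 2 (stride
        # defaults to the kernel size) over [1, 2, 3, 4] picks values [2, 4]
        # at flat positions [1, 3].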
        def expected_indices(dim, dtype):
            if dim == 1:
                return torch.tensor([1, 3], dtype=dtype).repeat(2, 2, 1)
            if dim == 2:
                return torch.tensor([[5, 7], [13, 15]], dtype=dtype).repeat(2, 2, 1, 1)

        def expected_grad(dim, dtype):
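            # Gradient flows only to each window's argmax. Stacking
            # (zero, grad, zero, grad) lifts the (dim - 1)-d pattern one
            # dimension up, matching maxima that sit at odd positions along
            # every pooled axis.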
            if dim == 1:
                return torch.tensor([0, 1, 0, 1], dtype=dtype).repeat(2, 2, 1)
            grad = expected_grad(dim - 1, dtype=dtype)
            zero = torch.zeros(grad.size(), dtype=dtype)
            return torch.stack((zero, grad, zero, grad), 2)

        def expected_output(dim, dtype):
            if dim == 1:
                return torch.arange(2, 17, 2, dtype=dtype).view(2, 2, 2)
            if dim == 2:
                col = torch.arange(6, 63, 8, dtype=dtype)
                return torch.stack([col, col + 2], 1).view(2, 2, 2, 2)

        if adaptive:
            cls_name = "AdaptiveMaxPool{}d".format(num_dim)  # noqa: UP032
        else:
            # FIXME(#105716): Test fails when using f-string
            cls_name = "MaxPool{}d".format(num_dim)  # noqa: UP032
        module_cls = getattr(nn, cls_name)
        module = module_cls(2, return_indices=True).to(device, dtype=dtype)
        numel = 4 ** (num_dim + 1)
        input = (
            torch.arange(1, numel + 1)
            .view(2, 2, *repeat(4, num_dim))
            .to(device, dtype=dtype)
        )
        input_var = input.clone().detach().requires_grad_()

        # Check forward
        output, indices = module(input_var)
        if num_dim != 3:
            expected_indices = expected_indices(num_dim, dtype=indices.data.dtype)
            expected_output = expected_output(num_dim, dtype=output.data.dtype)
            self.assertEqual(indices.dim(), input.dim())
            self.assertEqual(indices.data.squeeze(), expected_indices)
            self.assertEqual(output.data.squeeze(), expected_output)
        self.assertTrue(output.requires_grad)
        self.assertFalse(indices.requires_grad)

        # Make sure backward works
        grad_output = torch.ones(output.size(), device=device, dtype=dtype)
        output.backward(grad_output, retain_graph=True)
        expected_grad = expected_grad(num_dim, dtype=input_var.grad.data.dtype)
        self.assertEqual(input_var.grad.data, expected_grad.view_as(input))

        # Make sure backward after changing indices will result in an error
        indices.add_(1)
        self.assertRaises(RuntimeError, lambda: output.backward(grad_output))

        # Make sure -Infinity is handled correctly
        t = torch.tensor([[[float("-inf")]]])
        m = nn.MaxPool1d(kernel_size=1, return_indices=True)
        output, indices = m(t)
        self.assertEqual(output[0, 0, 0], float("-inf"))
        self.assertEqual(indices[0, 0, 0], 0)

        t = torch.tensor([[[float("-inf")]]])
        m = nn.MaxPool2d(kernel_size=1, return_indices=True)
        output, indices = m(t)
        self.assertEqual(output[0, 0, 0], float("-inf"))
        self.assertEqual(indices[0, 0, 0], 0)

        t = torch.tensor([[[[float("-inf")]]]])
        m = nn.MaxPool3d(kernel_size=1, return_indices=True)
        output, indices = m(t)
        self.assertEqual(output[0, 0, 0, 0], float("-inf"))
        self.assertEqual(indices[0, 0, 0, 0], 0)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool1d_indices(self, device, dtype):
        self._test_maxpool_indices(1, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool2d_indices(self, device, dtype):
        self._test_maxpool_indices(2, device=device, dtype=dtype)

    @skipIfMps
    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool3d_indices(self, device, dtype):
        self._test_maxpool_indices(3, device=device, dtype=dtype)

    @skipIfMps
    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_AdaptiveMaxPool1d_indices(self, device, dtype):
        self._test_maxpool_indices(1, adaptive=True, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_AdaptiveMaxPool2d_indices(self, device, dtype):
        self._test_maxpool_indices(2, adaptive=True, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_AdaptiveMaxPool3d_indices(self, device, dtype):
        self._test_maxpool_indices(3, adaptive=True, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_maxpool_indices_no_batch_dim(self, device, dtype):
        """Check that indices with no batch dim are consistent with a single batch."""
        max_pool_cases = [
            (
                nn.MaxPool1d(3, return_indices=True),
                torch.randn(3, 5, device=device, dtype=dtype),
            ),
            (
                nn.MaxPool2d(3, return_indices=True),
                torch.randn(3, 5, 6, device=device, dtype=dtype),
            ),
            (
                nn.MaxPool3d(3, return_indices=True),
                torch.randn(3, 5, 6, 7, device=device, dtype=dtype),
            ),
            (
                nn.AdaptiveMaxPool1d(3, return_indices=True),
                torch.randn(3, 5, device=device, dtype=dtype),
            ),
            (
                nn.AdaptiveMaxPool2d(3, return_indices=True),
                torch.randn(3, 5, 6, device=device, dtype=dtype),
            ),
            (
                nn.AdaptiveMaxPool3d(3, return_indices=True),
                torch.randn(3, 5, 6, 7, device=device, dtype=dtype),
            ),
        ]

        for module, input in max_pool_cases:
            _, indices_no_batch = module(input)
            _, indices_single_batch = module(input.unsqueeze(0))
            self.assertEqual(indices_no_batch, indices_single_batch.squeeze(0))

    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float)
    @onlyNativeDeviceTypes  # TODO: Fails on XLA
    @gcIfJetson
    def test_max_pool_nan_inf(self, device, dtype):
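        # Max pooling should propagate NaN instead of dropping it through
        # always-false comparisons, and should return -inf when every element
        # in the window is -inf; both are checked with and without autograd.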
        for adaptive in ["", "adaptive_"]:
            for num_dim in [1, 2, 3]:
                fn_name = f"{adaptive}max_pool{num_dim}d"
                fn = getattr(F, fn_name)

                x = torch.full(
                    [1, 1] + num_dim * [3],
                    nan,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                res = fn(x, 1 if adaptive else 3)
                res.backward(torch.randn_like(res))
                self.assertTrue(math.isnan(res.item()))
                x.requires_grad_(False)
                res = fn(x, 1 if adaptive else 3)
                self.assertTrue(math.isnan(res.item()))

                x2 = torch.full(
                    [1, 1] + num_dim * [3],
                    -inf,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                res2 = fn(x2, 1 if adaptive else 3)
                res2.backward(torch.randn_like(res2))
                self.assertTrue(math.isinf(res2.item()))
                x2.requires_grad_(False)
                res2 = fn(x2, 1 if adaptive else 3)
                self.assertTrue(math.isinf(res2.item()))

    @expectedFailureMeta  # RuntimeError: Unrecognized tensor type ID: Meta
    @onlyNativeDeviceTypes
    def test_fractional_max_pool2d(self, device):
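        # Fractional max pooling (Graham, 2014) uses pseudo-randomly placed
        # pooling windows; passing _random_samples (values in [0, 1)) pins the
        # placement, making the op deterministic enough for gradcheck.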
        with set_default_dtype(torch.double):
            x = torch.randn(1, 2, 7, 7, requires_grad=True, device=device)
            samples = x.new(1, 2, 2).uniform_()

            def func(x):
                return F.fractional_max_pool2d(
                    x, (2, 2), output_size=(3, 3), _random_samples=samples
                )

            self.assertEqual(func(x).shape, (1, 2, 3, 3))
            gradcheck(func, [x])
            gradgradcheck(func, [x])

            x = torch.randn(2, 7, 7, requires_grad=True, device=device)
            self.assertEqual(func(x).shape, (2, 3, 3))
            if self.device_type != "cuda":
                # Reference: https://github.com/pytorch/pytorch/issues/52427
                # Raises -> RuntimeError: TensorAccessor expected 4 dims but tensor has 3
                # on CUDA in gradcheck
                gradcheck(func, [x])
                gradgradcheck(func, [x])

            for kernel_size in [(), (1,)]:
                with self.assertRaisesRegex(RuntimeError, "kernel_size must either"):
                    # Incorrect kernel_size
                    F.fractional_max_pool2d(
                        x,
                        kernel_size=kernel_size,
                        output_size=(3, 3),
                        _random_samples=samples,
                    )

            err_large_msg = "too large relative to input "
            err_out_size_msg = "output_size must either"
            for output_size, msg in [
                ((9, 3), err_large_msg + "height"),
                ((3, 9), err_large_msg + "width"),
                ((3,), err_out_size_msg),
                ((), err_out_size_msg),
            ]:
                with self.assertRaisesRegex(RuntimeError, msg):
                    # Incorrect output_size
                    F.fractional_max_pool2d(
                        x, (2, 2), output_size=output_size, _random_samples=samples
                    )

    @expectedFailureMeta  # RuntimeError: Unrecognized tensor type ID: Meta
    @onlyNativeDeviceTypes
    def test_fractional_max_pool3d(self, device):
        with set_default_dtype(torch.double):
            x = torch.randn(1, 2, 7, 7, 7, requires_grad=True, device=device)
            samples = x.new(1, 2, 3).uniform_()

            def func(x):
                return F.fractional_max_pool3d(
                    x, (2, 2, 2), output_size=(3, 3, 3), _random_samples=samples
                )

            self.assertEqual(func(x).shape, (1, 2, 3, 3, 3))
            gradcheck(func, [x])
            gradgradcheck(func, [x])

            x = torch.randn(2, 7, 7, 7, requires_grad=True, device=device)
            self.assertEqual(func(x).shape, (2, 3, 3, 3))
            gradcheck(func, [x])
            gradgradcheck(func, [x])

            for kernel_size in [(), (1,), (1, 1)]:
                with self.assertRaisesRegex(RuntimeError, "kernel_size must either"):
                    # Incorrect kernel_size
                    F.fractional_max_pool3d(
                        x,
                        kernel_size=kernel_size,
                        output_size=(3, 3, 3),
                        _random_samples=samples,
                    )

            err_large_msg = "too large relative to input "
            err_out_size_msg = "output_size must either"
            for output_size, msg in [
                ((9, 3, 3), err_large_msg + "time"),
                ((3, 9, 3), err_large_msg + "height"),
                ((3, 3, 9), err_large_msg + "width"),
                ((3, 3), err_out_size_msg),
                ((3,), err_out_size_msg),
                ((), err_out_size_msg),
            ]:
                with self.assertRaisesRegex(RuntimeError, msg):
                    # Incorrect output_size
                    F.fractional_max_pool3d(
                        x, (2, 2, 2), output_size=output_size, _random_samples=samples
                    )

    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float)
    @onlyNativeDeviceTypes  # TODO: Fails on XLA
    def test_fractional_max_pool_nan_inf(self, device, dtype):
        for num_dim in [2, 3]:
            fn_name = f"FractionalMaxPool{num_dim}d"
            fn = getattr(nn, fn_name)(kernel_size=2, output_size=1)
            x = torch.full(
                [1, 1] + num_dim * [3],
                nan,
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            res = fn(x)
            res.backward(torch.randn_like(res))
            self.assertTrue(math.isnan(res.item()))

            x2 = torch.full(
                [1, 1] + num_dim * [3],
                -inf,
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            res2 = fn(x2)
            res2.backward(torch.randn_like(res2))
            self.assertTrue(math.isinf(res2.item()))

    @onlyNativeDeviceTypes  # TODO: RuntimeError message different on XLA
    def test_pooling_zero_stride(self, device):
        for op in ("max", "avg"):
            for num_dim in [1, 2, 3]:
                fn_name = f"{op}_pool{num_dim}d"
                fn = getattr(F, fn_name)
                x = torch.ones([1, 2] + num_dim * [4], device=device, dtype=torch.float)
                self.assertRaisesRegex(
                    RuntimeError,
                    r"stride should not be zero|stride must be greater than zero",
                    lambda: fn(x, kernel_size=2, stride=0),
                )

                fn_module_name = f"{op.title()}Pool{num_dim}d"
                fn_module = getattr(nn, fn_module_name)(kernel_size=2, stride=0)
                self.assertRaisesRegex(
                    RuntimeError,
                    r"stride should not be zero|stride must be greater than zero",
                    lambda: fn_module(x),
                )

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_pool_large_size(self, device, dtype):
        for op in ("max", "avg"):
            for num_dim in [1, 2, 3]:
                fn_name = f"{op}_pool{num_dim}d"
                fn = getattr(F, fn_name)
                # 16777217 == 2**24 + 1 is the smallest positive integer that
                # float32 cannot represent exactly, so size arithmetic done in
                # float32 would silently round it
                x = torch.ones(
                    [1, 1, 16777217] + (num_dim - 1) * [1], device=device, dtype=dtype
                )
                res = fn(x, 1, stride=1, padding=0)
                # check that the output shape was still computed correctly
                self.assertEqual(x.shape[2], res.shape[2])

    @onlyCUDA
    @largeTensorTest("6GB")
    def test_pooling_large(self, device):
        def helper(pool):
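            # (2**7 + 10) * 2**24 = 2,315,255,808 elements exceeds 2**31 - 1,
            # forcing 64-bit indexing in the CUDA kernels, as the assert below
            # verifies.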
            inp = torch.randn(
                2**7 + 10, 2**8, 2**8, 2**8, dtype=torch.half, device="cuda"
            )
            self.assertTrue(inp.numel() > 2**31 - 1)
            out = pool(inp)
            torch.cuda.synchronize()  # asserts test finishes normally without raising errors

        helper(nn.MaxPool2d(4, 4))
        helper(nn.AvgPool2d(4, 4))
        helper(nn.FractionalMaxPool2d(4, 4))
        helper(nn.AdaptiveMaxPool2d((2**6, 2**6)))
        helper(nn.AdaptiveAvgPool2d((2**6, 2**6)))

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_pool_invalid_size(self, device, dtype):
        for op in ("max", "avg"):
            for num_dim in [1, 2, 3]:
                fn_name = f"{op}_pool{num_dim}d"
                if op == "max":
                    # New implementation without indices supports empty tensors
                    # TODO(Heitor) change once with_indices code is updated
                    fn_name += "_with_indices"
                fn = getattr(F, fn_name)
                # Use a configuration that yields a zero-sized output only under
                # a correct floor division by the stride: with dilation the
                # effective kernel spans dilation * (kernel - 1) + 1 = 5, which
                # exceeds the input size 4, so floor((4 - 5) / 2 + 1) = 0 and
                # the call must raise.
                x = torch.ones([1, 1] + num_dim * [4], device=device, dtype=dtype)
                with self.assertRaisesRegex(RuntimeError, r"too small|smaller than"):
                    try:
                        res = fn(x, 3, stride=2, padding=0, dilation=2)
                    except TypeError:
                        # some implementations do not support dilation
                        res = fn(x, 6, stride=2, padding=0)

    @onlyCUDA
    def test_pooling_bfloat16(self, device):
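        # bfloat16 keeps float32's exponent range but only an 8-bit significand
        # (7 stored bits), hence the loose prec=0.05 tolerance used here.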
        _test_bfloat16_ops(
            self,
            torch.nn.AvgPool1d(3, stride=2),
            device,
            inp_dims=(8, 4, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AvgPool2d(3, stride=2),
            device,
            inp_dims=(8, 4, 16, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AvgPool3d(3, stride=2),
            device,
            inp_dims=(8, 4, 16, 16, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self, torch.nn.AdaptiveAvgPool1d(3), device, inp_dims=(8, 4, 16), prec=0.05
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AdaptiveAvgPool2d((3, 5)),
            device,
            inp_dims=(8, 4, 16, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AdaptiveAvgPool3d((3, 5, 7)),
            device,
            inp_dims=(8, 4, 16, 16, 16),
            prec=0.05,
        )

    def test_maxpool3d_non_square_backward(self, device):
        # A previous CUDA implementation of this backward computed the kernel
        # launch grid with the last two dimensions interchanged, so the tail of
        # the longer dimension was ignored. Check that every position receives
        # a gradient.
        for dim in (2, 3, 4):
            shape = tuple(32 if i != dim else 256 for i in range(4))
            x = torch.randn(shape, device=device, requires_grad=True)
            F.max_pool3d(x, kernel_size=(1, 1, 1)).sum().backward()
            self.assertEqual(x.grad, torch.ones_like(x.grad))

    @slowTest
    def test_adaptive_pool_odd_size(self, device):
        # See https://github.com/pytorch/pytorch/issues/81409
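        # Adaptive pooling computes each output window as
        # start = floor(i * in_size / out_size),
        # end = ceil((i + 1) * in_size / out_size); the large odd sizes below
        # stress that index math (see the issue above).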
        Ih, Iw, Oh, Ow = 5873, 3693, 3527, 2219
        imgs = torch.randint(low=0, high=256, size=(11, Ih, Iw), dtype=torch.float)
        imgs_ = F.adaptive_avg_pool2d(imgs, (Oh, Ow))
        imgs_ = F.adaptive_max_pool2d(imgs, (Oh, Ow))

        Id, Ih, Iw, Od, Oh, Ow = 3, 5873, 3693, 3, 3527, 2219
        imgs = torch.randint(low=0, high=256, size=(3, Id, Ih, Iw), dtype=torch.float)
        imgs_ = F.adaptive_avg_pool3d(imgs, (Od, Oh, Ow))
        imgs_ = F.adaptive_max_pool3d(imgs, (Od, Oh, Ow))


instantiate_device_type_tests(TestPoolingNNDeviceType, globals())
instantiate_parametrized_tests(TestPoolingNN)

if __name__ == "__main__":
    run_tests()