# Owner(s): ["module: nn"]
import itertools
import math
import operator
import os
import random
import subprocess
import sys
import unittest
from functools import partial, reduce
from itertools import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import inf, nan
from torch.autograd import gradcheck, gradgradcheck
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    expectedFailureMeta,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    skipCUDAIfRocm,
    TEST_WITH_ROCM,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_nn import (
    _test_bfloat16_ops,
    _test_module_empty_input,
    NNTestCase,
)
from torch.testing._internal.common_utils import (
    gcIfJetson,
    instantiate_parametrized_tests,
    parametrize as parametrize_test,
    run_tests,
    set_default_dtype,
    skipIfMps,
    skipIfTorchDynamo,
    slowTest,
    subtest,
    TEST_WITH_UBSAN,
    TestCase,
)


class TestAvgPool(TestCase):
    def _sum_pool2d(self, x, kernel_size):
        windows = torch.nn.functional.unfold(
            x, kernel_size=kernel_size, stride=kernel_size
        )
        return torch.sum(windows, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        # unfold does not support 3D sliding windows, so split the tensor into
        # 2D slices along the first dimension and sum the pooled slices.
        h = kernel_size[0]
        split_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # _sum_pool2d expects a (1, 1, n, m) view, so unsqueeze twice
        split_x = [
            self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:])
            for t in split_x
        ]
        joined_x = torch.cat(split_x)
        return joined_x.view(1, joined_x.numel())

    def _avg_pool2d(self, x, kernel_size):
        size = reduce(operator.mul, kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        size = reduce(operator.mul, kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

    def test_doubletensor_avg_pool2d(self):
        n, m = 5, 8
        input = torch.rand(1, 1, n, m, dtype=torch.double)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                actual = torch.nn.functional.avg_pool2d(input[0], (i, j))
                actual = actual.view(1, actual.numel())
                expected = self._avg_pool2d(input, (i, j))
                self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_doubletensor_avg_pool2d_with_divisor(self):
        n, m = 3, 3
        input = torch.rand(1, 1, n, m, dtype=torch.double)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)
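
    # A minimal sketch of how the unfold-based reference above works, assuming
    # a (1, 1, n, m) single-channel input: each column produced by unfold is
    # one pooling window, so summing over dim 1 gives the per-window sums, and
    # dividing by the window size turns them into averages:
    #
    #   windows = F.unfold(x, kernel_size=(kh, kw), stride=(kh, kw))  # (1, kh*kw, L)
    #   avg = windows.sum(dim=1) / (kh * kw)                          # (1, L)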

    def test_doubletensor_avg_pool3d(self):
        h, w, d = 5, 6, 7
        input = torch.rand(h, w, d, dtype=torch.double)
        for i in range(1, h + 1):
            for j in range(1, w + 1):
                for k in range(1, d + 1):
                    actual = torch.nn.functional.avg_pool3d(
                        input.unsqueeze(0), (i, j, k)
                    )
                    actual = actual.view(1, actual.numel())
                    expected = self._avg_pool3d(input, (i, j, k))
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_doubletensor_avg_pool3d_with_divisor(self):
        h, w, d = 6, 5, 7
        input = torch.rand(h, w, d, dtype=torch.double)
        for i in range(1, h + 1):
            for j in range(1, w + 1):
                for k in range(1, d + 1):
                    for divisor in [1, 7, i * j]:
                        actual = torch.nn.functional.avg_pool3d(
                            input.unsqueeze(0), (i, j, k), divisor_override=divisor
                        )
                        actual = actual.view(1, actual.numel())
                        expected = self._sum_pool3d(input, (i, j, k)) / divisor
                        self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool1d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4))
        y = torch.nn.functional.avg_pool1d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=1, stride=2
        )
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool1d(
                x.to("cuda"),
                ceil_mode=True,
                count_include_pad=True,
                kernel_size=1,
                stride=2,
            )
            self.assertTrue(not torch.isnan(y).any())

    def test_avg_pool2d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x,
            ceil_mode=True,
            count_include_pad=True,
            kernel_size=(1, 2),
            padding=(0, 1),
            stride=2,
        )
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool2d(
                x.to("cuda"),
                ceil_mode=True,
                count_include_pad=True,
                kernel_size=(1, 2),
                padding=(0, 1),
                stride=2,
            )
            self.assertTrue(not torch.isnan(y).any())

    def test_avg_pool3d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4, 4))
        y = torch.nn.functional.avg_pool3d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2, 3), stride=2
        )
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool3d(
                x.to("cuda"),
                ceil_mode=True,
                count_include_pad=True,
                kernel_size=(1, 2, 3),
                stride=2,
            )
            self.assertTrue(not torch.isnan(y).any())


class TestPoolingNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def test_adaptive_pooling_size_none(self):
        for numel in (2, 3):
            for pool_type in ("Max", "Avg"):
                cls_name = f"Adaptive{pool_type}Pool{numel}d"
                module_cls = getattr(nn, cls_name)
                output_size = (2,) * (numel - 1) + (None,)
                module = module_cls(output_size)

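                # None in output_size keeps that dimension's input size, so the
                # last output dim below should match the input's last dim (4).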
                input = torch.randn((4,) * (numel + 1))
                output = module(input)
                self.assertEqual(output.size(), (4,) + (2,) * (numel - 1) + (4,))

    @unittest.skipIf(TEST_WITH_UBSAN, "signed integer overflow error with UBSAN")
    def test_adaptive_pooling_size_overflow(self):
        # 0x3fffffffffffffff * 2 * 2 = 0xfffffffffffffffc = -4 as int64_t
        # Tensor::numel() returns int64_t, so the following checks that
        # negative allocation sizes are handled correctly.
        self.assertRaises(
            RuntimeError,
            lambda: torch.nn.AdaptiveMaxPool1d(0x3FFFFFFFFFFFFFFF)(
                torch.empty([2, 2, 2])
            ),
        )

    def test_adaptive_pooling_avg_nhwc(self):
        device_list = ["cpu"]
        if TEST_CUDA:
            device_list.append("cuda")

        for device in device_list:
            input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device)
            input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
            grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device)
            pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

    def test_adaptive_pooling_avg_nhwc_non_contiguous(self):
        device_list = ["cpu"]
        if TEST_CUDA:
            device_list.append("cuda")

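        # Slicing every other channel below makes input and grad
        # non-contiguous in any memory format, so this exercises the
        # channels-last kernels with arbitrary strides.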
        for device in device_list:
            input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device)
            input = input.contiguous(memory_format=torch.channels_last)
            input = input[:, ::2, :, :].requires_grad_()
            grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device)
            grad = grad[:, ::2, :, :]
            pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

    def test_adaptive_pooling_lower_precision(self):
        def _test_adaptive_pooling_lower_precision(
            self, device, dtype, mod, memory_format
        ):
            input = torch.randint(1, 10, (3, 19, 8, 8), dtype=torch.float32)
            input = input.to(device).to(memory_format=memory_format).requires_grad_()
            pool = mod((7, 7)).to(device)

            input2 = input.detach().clone().to(dtype=dtype).requires_grad_(True)

            out = pool(input)
            out.sum().backward()
            out2 = pool(input2)
            out2.sum().backward()

            self.assertTrue(out2.is_contiguous(memory_format=memory_format))
            self.assertEqual(out2.dtype, dtype)
            self.assertEqual(input2.grad.dtype, dtype)
            self.assertEqual(out, out2.float(), atol=0.1, rtol=0)
            self.assertEqual(input.grad, input2.grad.float(), atol=0.1, rtol=0)

        device_list = ["cpu"]
        for device in device_list:
            for dtype in [torch.bfloat16, torch.float16]:
                _test_adaptive_pooling_lower_precision(
                    self,
                    device,
                    dtype,
                    torch.nn.AdaptiveAvgPool2d,
                    torch.contiguous_format,
                )
                _test_adaptive_pooling_lower_precision(
                    self, device, dtype, torch.nn.AdaptiveAvgPool2d, torch.channels_last
                )
                _test_adaptive_pooling_lower_precision(
                    self,
                    device,
                    dtype,
                    torch.nn.AdaptiveMaxPool2d,
                    torch.contiguous_format,
                )
                _test_adaptive_pooling_lower_precision(
                    self, device, dtype, torch.nn.AdaptiveMaxPool2d, torch.channels_last
                )

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @largeTensorTest("12GB", device="cuda")
    def test_adaptive_pooling_avg_nhwc_launch_config_backward(self):
        input = torch.randint(
            1, 10, (1, 32, 2**17 + 1, 32), dtype=torch.float32, device="cuda"
        )
        input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
        grad = torch.randint(1, 10, (1, 32, 10, 32), dtype=torch.float32, device="cuda")

        pool = torch.nn.AdaptiveAvgPool2d((10, 32)).cuda()

        ref_input = input.detach().clone().contiguous().requires_grad_(True)
        ref_grad = grad.detach().clone().contiguous()
        ref_pool = torch.nn.AdaptiveAvgPool2d((10, 32)).cuda()

        out = pool(input)
        out.backward(grad)
        ref_out = ref_pool(ref_input)
        ref_out.backward(ref_grad)

        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(ref_out.is_contiguous())
        self.assertEqual(out, ref_out)
        self.assertEqual(input.grad, ref_input.grad)
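
    # Both launch-config tests use a spatial extent of 2**17 + 1, presumably
    # chosen to exceed what a single CUDA kernel launch's grid configuration
    # can cover, so the kernels must tile or loop over the extra work.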

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @largeTensorTest("12GB", device="cuda")
    def test_adaptive_pooling_avg_nhwc_launch_config_forward(self):
        input = torch.randint(
            1, 10, (1, 32, 16, 16), dtype=torch.float32, device="cuda"
        )
        input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
        pool = torch.nn.AdaptiveAvgPool2d((2**17 + 1, 32)).cuda()

        ref_input = input.detach().clone().contiguous().requires_grad_(True)
        ref_pool = torch.nn.AdaptiveAvgPool2d((2**17 + 1, 32)).cuda()

        out = pool(input)
        ref_out = ref_pool(ref_input)

        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(ref_out.is_contiguous())
        self.assertEqual(out, ref_out)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_adaptive_avg_pooling_overflow(self):
        input = torch.randint(
            -256, 256, (20, 32, 256, 256), dtype=torch.half, device="cuda"
        )
        avg_pool = torch.nn.AdaptiveAvgPool2d((2, 2))
        out = avg_pool(input)
        self.assertFalse(torch.isinf(out).any())
        self.assertFalse(torch.isnan(out).any())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_adaptive_avg_pooling_nhwc_overflow(self):
        input = torch.randint(
            -256, 256, (20, 32, 256, 256), dtype=torch.half, device="cuda"
        )
        input = input.contiguous(memory_format=torch.channels_last)
        avg_pool = torch.nn.AdaptiveAvgPool2d((2, 2))
        out = avg_pool(input)
        self.assertFalse(torch.isinf(out).any())
        self.assertFalse(torch.isnan(out).any())

    def test_MaxUnpool2d_output_size(self):
        m = nn.MaxPool2d(3, stride=2, return_indices=True)
        mu = nn.MaxUnpool2d(3, stride=2)
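        # Pooling a 6x6 input with kernel 3 and stride 2 yields indices that
        # can point at position (4, 4); those do not fit in the default 5x5
        # unpool output ((2 - 1) * 2 + 3 = 5), so unpooling without an
        # explicit output_size should fail.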
        big_t = torch.rand(1, 1, 6, 6)
        big_t[0][0][4][4] = 100
        output_big, indices_big = m(big_t)
        self.assertRaises(RuntimeError, lambda: mu(output_big, indices_big))

        small_t = torch.rand(1, 1, 5, 5)
        for i in range(0, 4, 2):
            for j in range(0, 4, 2):
                small_t[:, :, i, j] = 100
        output_small, indices_small = m(small_t)
        for h in range(3, 10):
            for w in range(3, 10):
                if 4 <= h <= 6 and 4 <= w <= 6:
                    size = (h, w)
                    if h == 6:
                        size = (1, 1) + size

                    mu(output_small, indices_small, output_size=size)
                else:
                    self.assertRaises(
                        ValueError, lambda: mu(output_small, indices_small, (h, w))
                    )

    def test_max_unpool2d_nhwc_cpu(self):
        input = torch.randn(2, 10, 9, 9).float().cpu()
        input = input.contiguous(memory_format=torch.channels_last)
        ref_input = input.clone().contiguous()

        pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()
        ref_pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()

        out, ind = pool(input)
        ref_out, ref_ind = ref_pool(ref_input)
        out.requires_grad_()
        ref_out.requires_grad_()

        unpool = nn.MaxUnpool2d(3, stride=2).cpu()
        ref_unpool = nn.MaxUnpool2d(3, stride=2).cpu()

        upout = unpool(out, ind)
        ref_upout = ref_unpool(ref_out, ref_ind)

        grad = torch.randn(upout.size()).float().cpu()
        grad = grad.contiguous(memory_format=torch.channels_last)
        ref_grad = grad.clone().contiguous()

        upout.backward(grad)
        ref_upout.backward(ref_grad)

        self.assertTrue(upout.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(ref_upout.is_contiguous())
        self.assertTrue(torch.allclose(upout, ref_upout))
        self.assertTrue(torch.allclose(out.grad, ref_out.grad))

    def test_max_unpool(self):
        with set_default_dtype(torch.double):
            # Test 1D
            output, indices = F.max_pool1d(
                torch.randn([1, 1, 4]), 2, stride=2, return_indices=True
            )
            self.assertEqual(
                F.max_unpool1d(output, indices, 2),
                F.max_unpool1d(output, indices, 2, stride=2),
            )

            # Test list / tuple passed as argument to max_unpool1d
            input = torch.randn([1, 1, 5], requires_grad=True)
            output, indices = F.max_pool1d(input, 2, stride=2, return_indices=True)
            self.assertEqual(
                F.max_unpool1d(output, indices, 2, stride=2, output_size=input.shape),
                F.max_unpool1d(output, indices, 2, stride=2, output_size=input.size()),
            )
            gradcheck(F.max_unpool1d, (output, indices, 2), check_forward_ad=True)

            # Test 2D
            output, indices = F.max_pool2d(
                torch.randn([1, 1, 4, 4], requires_grad=True),
                2,
                stride=2,
                return_indices=True,
            )
            self.assertEqual(
                F.max_unpool2d(output, indices, 2),
                F.max_unpool2d(output, indices, 2, stride=2),
            )
            gradcheck(F.max_unpool2d, (output, indices, 2), check_forward_ad=True)

            # Test 3D
            output, indices = F.max_pool3d(
                torch.randn([4, 4, 4, 4, 4], requires_grad=True),
                2,
                stride=2,
                return_indices=True,
            )
            self.assertEqual(
                F.max_unpool3d(output, indices, 2),
                F.max_unpool3d(output, indices, 2, stride=2),
            )
            gradcheck(F.max_unpool3d, (output, indices, 2), check_forward_ad=True)

    def test_max_unpool3d_input_check(self):
        x = torch.ones(1, 3, 1, 1, 1)
        with self.assertRaises(RuntimeError):
            F.max_unpool3d(x, torch.zeros(x.shape, dtype=int), [1, 1])

    def test_quantized_max_pool1d_empty_kernel(self):
        # This used to segfault when called with an empty kernel
        # see https://github.com/pytorch/pytorch/issues/116323
        base = torch.randn(1)
        temp_tensor = torch.quantize_per_tensor(base, 0.1, 10, torch.quint2x4)
        with self.assertRaises(RuntimeError):
            torch.quantized_max_pool1d(temp_tensor, [])


class TestPoolingNNDeviceType(NNTestCase):
    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    def test_adaptive_pooling_zero_batch(self, dtype, device):
        inp = torch.ones(0, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool1d(5).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        inp = torch.ones(0, 10, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool2d((5, 5)).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        inp = torch.ones(0, 10, 10, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool3d((5, 5, 5)).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    # These tests verify that adaptive_{avg, max}_pool and their variants raise
    # an error on backward propagation when output_size = 0. They are written
    # explicitly because ErrorInputs does not support backward calls.
    # Issue: https://github.com/pytorch/pytorch/issues/78868
    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16, torch.float16)
    def test_adaptive_pooling_empty_output_size(self, dtype, device):
        error_msg = (
            "Expected grad_output to have non-zero size for non-batch dimensions"
        )

        make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
        input = make_arg((1, 64, 10, 9))
        output_size = 0

        fns = (
            nn.functional.adaptive_avg_pool2d,
            nn.functional.adaptive_avg_pool3d,
            nn.functional.adaptive_max_pool2d,
            nn.functional.adaptive_max_pool3d,
        )

        for fn in fns:
            with self.assertRaisesRegex(RuntimeError, error_msg):
                fn(input, output_size).sum().backward()

        fns2 = (
            nn.functional.adaptive_avg_pool1d,
            nn.functional.adaptive_max_pool1d,
        )
        input2 = make_arg((1, 64))

        for fn in fns2:
            with self.assertRaisesRegex(RuntimeError, error_msg):
                fn(input2, output_size).sum().backward()

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool2d_zero_batch(self, device):
        mod = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
        inp = torch.ones(0, 16, 50, 32, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected input"):
            inp = torch.randn(1, 0, 50, 32, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool3d_zero_batch(self, device):
        mod = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)).to(device)
        inp = torch.ones(0, 16, 50, 32, 32, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected input"):
            inp = torch.randn(1, 0, 50, 32, 32, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool2d_zero_out_size(self, device):
        mod = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1])
        inp = torch.rand([16, 50, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((16, 50, 0, 1), device=device))

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool3d_zero_out_size(self, device):
        mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[0, 1, 1])
        inp = torch.rand([16, 50, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device))

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool2d_zero_samples(self, device):
        samples = torch.rand([0, 16, 2], device=device)
        mod = nn.FractionalMaxPool2d(
            [2, 2], output_size=[1, 1], _random_samples=samples
        )
        inp = torch.randn([0, 16, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((0, 16, 1, 1), device=device))

        inp1 = torch.randn([1, 16, 32, 32], device=device)
        with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"):
            out1 = mod(inp1)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool3d_zero_samples(self, device):
        samples = torch.rand([0, 16, 3], device=device)
        mod = nn.FractionalMaxPool3d(
            [3, 2, 2], output_size=[1, 1, 1], _random_samples=samples
        )
        inp = torch.randn([0, 16, 50, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((0, 16, 1, 1, 1), device=device))

        inp1 = torch.randn([1, 16, 50, 32, 32], device=device)
        with self.assertRaisesRegex(RuntimeError, "Expect _random_samples"):
            out1 = mod(inp1)

    @onlyNativeDeviceTypes
    def test_MaxPool_zero_batch_dim(self, device):
        inp = torch.randn(0, 16, 50, device=device)
        mod = torch.nn.MaxPool1d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        # 1D is supposed to be okay with 0 numel() inputs, so don't test
        # error raising for that case.

        inp = torch.randn(0, 16, 50, 32, device=device)
        mod = torch.nn.MaxPool2d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.randn(1, 0, 50, 32, device=device)
            mod(inp)

        inp = torch.ones(0, 16, 50, 44, 31, device=device)
        mod = torch.nn.MaxPool3d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.ones(1, 0, 50, 44, 31, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_MaxUnpool_zero_batch_dim(self, device):
        pool = torch.nn.MaxPool1d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool1d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        output.requires_grad_(True)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

        pool = torch.nn.MaxPool2d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool2d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

        pool = torch.nn.MaxPool3d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool3d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        output.requires_grad_(True)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

    @slowTest
    @onlyNativeDeviceTypes
    @skipCUDAIfRocm
    @parametrize_test(
        "module_name,module_size,output_size,test_index,should_error",
        [
            # Some tests are failing in trunk https://github.com/pytorch/pytorch/issues/103854
            subtest(
                ("MaxUnpool2d", (2, 2), (1, 3, 4, 5), -1, True),
                name="case1",
            ),
            subtest(
                ("MaxUnpool2d", (2, 2), (1, 3, 4, 5), 2 * 2 * 4 * 5, True),
                name="case2",
            ),
            subtest(
                ("MaxUnpool2d", (2, 2), (1, 3, 4, 5), (2 * 2 * 4 * 5) - 1, False),
                name="case3",
            ),
            subtest(
                ("MaxUnpool2d", (2, 3), (2, 1, 4, 2), 2 * 3 * 4 * 2, True),
                name="case4",
            ),
            subtest(
                ("MaxUnpool2d", (2, 3), (2, 1, 4, 2), (2 * 3 * 4 * 2) - 1, False),
                name="case5",
            ),
            subtest(
                ("MaxUnpool3d", (2, 2, 2), (1, 3, 4, 5), -1, True),
                name="case6",
            ),
            subtest(
                ("MaxUnpool3d", (2, 2, 2), (1, 3, 4, 5), 2 * 2 * 2 * 3 * 4 * 5, True),
                name="case7",
            ),
            subtest(
                (
                    "MaxUnpool3d",
                    (2, 2, 2),
                    (1, 3, 4, 5),
                    (2 * 2 * 2 * 3 * 4 * 5) - 1,
                    False,
                ),
                name="case8",
            ),
            subtest(
                ("MaxUnpool3d", (2, 2, 2), (2, 3, 4, 1), 2 * 2 * 2 * 3 * 4 * 1, True),
                name="case9",
            ),
            subtest(
                (
                    "MaxUnpool3d",
                    (2, 2, 2),
                    (2, 3, 4, 1),
                    (2 * 2 * 2 * 3 * 4 * 1) - 1,
                    False,
                ),
                name="case10",
            ),
        ],
    )
    def test_MaxUnpool_index_errors(
        self, device, module_name, module_size, output_size, test_index, should_error
    ):
        # NOTE: CUDA tests need to be run in a subprocess because they cause device asserts
        if torch.device(device).type == "cuda":
            error_msgs = {
                "MaxUnpool2d": r"Assertion `maxind >= 0 && maxind < outputImageSize` failed",
                "MaxUnpool3d": r"Assertion `index >= 0 && index < outputImageSize` failed",
            }

            script = f"""
import torch
unpool = torch.nn.{module_name}({module_size}).to('{device}')
output = torch.rand({output_size}, dtype=torch.float32, device='{device}')
indices = torch.zeros({output_size}, dtype=torch.int64, device='{device}')
indices.flatten()[0] = {test_index}
unpool(output, indices)
torch.cuda.synchronize()
"""
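            # Run the snippet in a fresh interpreter: a device-side assert
            # leaves this process's CUDA context unusable, so the failure has
            # to be observed from outside.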
            p = subprocess.run(
                [sys.executable, "-c", script],
                cwd=os.path.dirname(os.path.realpath(__file__)),
                capture_output=True,
                text=True,
            )

            output = p.stdout + "\n" + p.stderr

            error_msg = error_msgs[module_name]

            if should_error:
                self.assertIn(error_msg, output, "The expected error was not found")
            else:
                self.assertNotIn("Error", output, "Should not have produced an error")
        else:
            module_class = getattr(torch.nn, module_name)
            unpool = module_class(module_size).to(device)
            output = torch.rand(output_size, dtype=torch.float32, device=device)
            indices = torch.zeros(output_size, dtype=torch.int64, device=device)
            indices.flatten()[0] = test_index

            if should_error:
                with self.assertRaisesRegex(
                    RuntimeError, r"Found an invalid max index:"
                ):
                    unpool(output, indices)
            else:
                unpool(output, indices)

    @onlyNativeDeviceTypes
    def test_AdaptiveMaxPool_zero_batch_dim(self, device):
        inp = torch.randn(0, 16, 50, device=device)
        mod = torch.nn.AdaptiveMaxPool1d(3).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.randn(1, 0, 50, device=device)
            mod(inp)

        inp = torch.randn(0, 16, 50, 32, device=device)
        mod = torch.nn.AdaptiveMaxPool2d(3).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)


    @onlyNativeDeviceTypes
    def test_AvgPool2d_empty(self, device):
        avgpool = torch.nn.AvgPool2d(3, stride=2).to(device)
        inp = torch.randn(0, 16, 20, 32, device=device)
        _test_module_empty_input(self, avgpool, inp, check_size=False)

        clast_inp = torch.randn(0, 16, 20, 32, device=device).contiguous(
            memory_format=torch.channels_last
        )
        _test_module_empty_input(self, avgpool, clast_inp, check_size=False)

        # test with empty non-batch input
        with self.assertRaisesRegex(RuntimeError, "3D or 4D"):
            inp = torch.randn(16, 0, 20, 32, device=device)
            avgpool(inp)

    def test_pooling_shape(self, device):
        """Test the output shape calculation for pooling functions"""

        # Checks output shape against expected for 1D, 2D and 3D
        def check(expected_out_shape, sizes, *args, **kwargs):
            for kernel in ["max", "avg"]:
                for i in [1, 2, 3]:
                    if hasattr(torch.nn.functional, f"{kernel}_pool{i}d"):
                        op = getattr(torch.nn.functional, f"{kernel}_pool{i}d")
                        t = torch.randn(sizes[: i + 2], device=device)
                        self.assertEqual(
                            op(t, *args, **kwargs).shape, expected_out_shape[: i + 2]
                        )

        check(
            (1, 1, 3, 3, 4),
            (1, 1, 5, 6, 7),
            kernel_size=1,
            stride=2,
            padding=0,
            ceil_mode=True,
        )
        check(
            (1, 1, 2, 3, 3),
            (1, 1, 3, 4, 5),
            kernel_size=2,
            stride=2,
            padding=1,
            ceil_mode=False,
        )
        check(
            (1, 1, 2, 3, 3),
            (1, 1, 3, 4, 5),
            kernel_size=2,
            stride=2,
            padding=1,
            ceil_mode=True,
        )

        # Test case from issue https://github.com/pytorch/pytorch/issues/45357
        x = torch.randn(1, 1, 6, 7, device=device)
        y = torch.nn.functional.max_pool2d(
            x, 1, stride=(2, 2), padding=0, ceil_mode=True
        )
        self.assertEqual(y.size(), (1, 1, 3, 4))
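
    # For reference, the shapes checked above follow the documented pooling
    # arithmetic: out = floor((in + 2 * pad - kernel) / stride) + 1, with floor
    # replaced by ceil when ceil_mode=True, except that a window which would
    # start beyond the input (or entirely in the right padding) is dropped.
    # The issue-45357 case is a worked example: for in=6, kernel=1, stride=2,
    # ceil_mode=True the ceil formula gives 4, but the fourth window would start
    # at position 6, past the last input element, so that dimension is 3; for
    # in=7 it is a genuine 4.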

    @onlyNativeDeviceTypes  # TODO: fix on XLA
    def test_adaptive_avg_pool2d_output_size_one(self, device):
        def helper(size, memory_format):
            x = torch.randint(
                1, 10, size, dtype=torch.float, device=device, requires_grad=True
            )
            if memory_format == "non_contiguous":
                x = x[::2, ::2, ::2, ::2]
            else:
                x = x.to(memory_format=memory_format)

            net = torch.nn.AdaptiveAvgPool2d((1, 1))
            out = net(x)
            ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))

            out.sum().backward()  # make sure it doesn't crash

            self.assertEqual(out, ref_out)
            if memory_format == torch.channels_last:
                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, c, c])
            else:
                self.assertTrue(out.is_contiguous())
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, 1, 1])

        for mf in (torch.contiguous_format, torch.channels_last, "non_contiguous"):
            helper((2, 3, 6, 6), mf)

    @onlyNativeDeviceTypes
    def test_adaptive_avg_pool3d_output_size_one(self, device):
        x = torch.randn(
            (2, 3, 6, 6, 6), dtype=torch.float, device=device, requires_grad=True
        )

        net = torch.nn.AdaptiveAvgPool3d(1)
        out = net(x)
        ref_out = x.contiguous().mean((-1, -2, -3)).view(out.shape)

        out.sum().backward()  # make sure it doesn't crash

        self.assertEqual(out, ref_out)
        self.assertTrue(out.is_contiguous())
        c = out.size(1)
        self.assertEqual(out.stride(), [c, 1, 1, 1, 1])

    @expectedFailureMeta  # Runtime Error not raised for meta
    @onlyNativeDeviceTypes
    @dtypes(torch.uint8, torch.int8, torch.short, torch.int, torch.long)
    def test_adaptive_pooling_no_suppot_input(self, device, dtype):
        for numel in (2, 3):
            for pool_type in ("Max", "Avg"):
                cls_name = f"Adaptive{pool_type}Pool{numel}d"
                module_cls = getattr(nn, cls_name)
                output_size = (2,) * numel
                module = module_cls(output_size)
                input = torch.randn((4,) * (numel + 1), device=device).to(dtype)
                with self.assertRaisesRegex(RuntimeError, "not implemented"):
                    output = module(input)
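
    # The "not implemented" assertion above is the expected behavior: adaptive
    # pooling kernels are provided only for floating-point dtypes, and integer
    # inputs should raise rather than be silently cast.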

    @onlyNativeDeviceTypes
    @gcIfJetson
    @dtypes(torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    def test_avg_pool2d_nhwc(self, device, dtype):
        def helper(
            n,
            c,
            h,
            w,
            kernel_size,
            stride=None,
            count_include_pad=True,
            divisor_override=None,
            padding=0,
        ):
            if stride is None:
                stride = kernel_size
            input = torch.randn(n, c, h, w, dtype=dtype, device=device)
            input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
            grad = torch.randn(
                n,
                c,
                (h - kernel_size) // stride + 1,
                (w - kernel_size) // stride + 1,
                dtype=dtype,
                device=device,
            )
            pool = torch.nn.AvgPool2d(
                kernel_size,
                stride=stride,
                count_include_pad=count_include_pad,
                divisor_override=divisor_override,
            ).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AvgPool2d(
                kernel_size,
                stride=stride,
                count_include_pad=count_include_pad,
                divisor_override=divisor_override,
            ).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

        helper(4, 8, 8, 8, 3)
        helper(4, 8, 8, 8, 3, count_include_pad=False, padding=1)
        helper(4, 8, 8, 8, 3, count_include_pad=False, padding=2, stride=2)
        helper(4, 8, 8, 8, 3, divisor_override=42)
        helper(4, 8, 8, 8, 7)
        # ROCm 16GB MI25 hits OOM error. Clear caching allocator prior to running large subtest.
        if TEST_WITH_ROCM and "cuda" in device:
            torch.cuda.empty_cache()
        helper(200, 512, 28, 28, 2)
        helper(4, 8, 7, 7, 3, stride=1)
        helper(4, 8, 7, 7, 3, padding=2, stride=1)
        helper(10, 512, 31, 31, 3, stride=2)
        helper(1, 129, 8, 8, 3, stride=2)
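
    # The channels_last (NHWC) tests in this file all share one recipe: run the
    # same module on a channels_last tensor and on a contiguous clone, then
    # compare values, gradients, and output memory formats. A minimal sketch of
    # that recipe (hypothetical shapes, not part of the suite):
    #
    #   x = torch.randn(2, 4, 8, 8).contiguous(memory_format=torch.channels_last)
    #   y = torch.nn.AvgPool2d(2)(x)
    #   assert y.is_contiguous(memory_format=torch.channels_last)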

    @onlyCPU
    @dtypes(torch.float, torch.double)
    def test_max_pool1d_corner_cases(self, device, dtype):
        def check(x, args, expected):
            model = torch.nn.MaxPool1d(*args)
            if isinstance(x, list):
                x = torch.tensor(x, device=device, dtype=dtype)
                expected = torch.tensor(expected, device=device, dtype=dtype)
            self.assertEqual(model(x), expected)

        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        check([[1]], (1, None, 0, 1, False, False), [[1]])
        check([[1]], (2, None, 1, 2, False, False), [[float("-inf")]])
        check(
            [[1], [1]],
            (2, None, 1, 2, False, False),
            [[float("-inf")], [float("-inf")]],
        )
        check([[1, 2]], (2, 1, 1, 2, False, False), [[2, 1]])
        check([[1, 2]], (2, 2, 1, 2, False, True), [[2, 2]])
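
    # The last two checks above are worth unpacking. With kernel_size=2,
    # padding=1, dilation=2, the input [1, 2] is padded (with -inf for max
    # pooling) to [-inf, 1, 2, -inf], and each window samples cells i and
    # i + 2. At stride 1 the windows are {-inf, 2} and {1, -inf}, giving
    # [2, 1]; at stride 2 with ceil_mode=True both surviving windows see the
    # 2, giving [2, 2].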

    @onlyCPU
    @dtypes(torch.float, torch.double)
    @skipIfTorchDynamo("OOMs https://github.com/pytorch/pytorch/issues/111320")
    def test_max_pool1d(self, device, dtype):
        # FIXME For now compare against max_pool1d with indices
        def check(x, *args, **kwargs):
            model = torch.nn.MaxPool1d(*args, **kwargs)
            ref_model = torch.nn.MaxPool1d(*args, **kwargs, return_indices=True)
            self.assertEqual(model(x), ref_model(x)[0])

        sizes = [random.sample(range(8, 128), 3) for _ in range(3)]
        kernel_sizes = random.sample(range(1, 5), 3)
        strides = random.sample(range(1, 5), 3)
        dilations = random.sample(range(1, 5), 3)
        ceil_modes = [True, False]

        for size, kernel_size, stride, dilation, ceil_mode in itertools.product(
            sizes, kernel_sizes, strides, dilations, ceil_modes
        ):
            padding = random.sample(range(0, math.floor(kernel_size / 2) + 1), 1)
            check(
                torch.randn(size, device=device, dtype=dtype),
                kernel_size,
                stride,
                padding,
                dilation,
                ceil_mode=ceil_mode,
            )

        # Non-contiguous test
        tensor = torch.randn(5, 151, 33, device=device, dtype=dtype)[::2, ::3, ::2]
        check(tensor, 3, 2, 1, 2, ceil_mode=True)
        check(tensor.transpose(1, 2), 3, 2, 1, 2, ceil_mode=True)

    @onlyCUDA
    @gcIfJetson
    def test_max_pool2d(self, device):
        def helper(n, c, h, w, ks):
            x = torch.randn(
                n, c, h, w, device="cuda", dtype=torch.float, requires_grad=True
            )
            ref_x = x.detach().clone().cpu().requires_grad_()

            pool = torch.nn.MaxPool2d(kernel_size=ks)

            y = pool(x)
            ref_y = pool(ref_x)

            y.sum().backward()
            ref_y.sum().backward()

            self.assertEqual(y, ref_y)
            self.assertEqual(x.grad, ref_x.grad)

        helper(2, 8, 4, 4, ks=2)
        helper(1, 100000, 32, 32, ks=4)
        helper(1, 100000, 1, 4, ks=(1, 4))  # test for max_pool1d
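
    # The 100000-channel cases are presumably stress tests for the CUDA launch
    # configuration, where grid dimensions scale with the batch/channel count;
    # the tiny spatial extents keep the memory footprint manageable.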

    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @gcIfJetson
    def test_max_pool2d_nhwc(self, device, dtype):
        def helper(n, c, h, w, kernel_size, stride=None):
            if stride is None:
                stride = kernel_size
            input = torch.randn(n, c, h, w, dtype=dtype, device=device)
            input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
            grad = torch.randn(
                n,
                c,
                (h - kernel_size) // stride + 1,
                (w - kernel_size) // stride + 1,
                dtype=dtype,
                device=device,
            )
            pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(
                device
            )

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(
                device
            )

            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            self.assertEqual(input.grad, ref_input.grad)

        helper(4, 8, 8, 8, 7)
        helper(200, 512, 28, 28, 2)
        helper(4, 8, 7, 7, 3, stride=1)
        helper(10, 512, 31, 31, 3, stride=2)
        helper(1, 129, 8, 8, 3, stride=2)

    @onlyCPU
    @dtypes(torch.int32, torch.int64)
    def test_max_pool2d_corner_cases(self, device, dtype):
        def check(x, args, expected, memory_format):
            model = torch.nn.MaxPool2d(*args)
            if isinstance(x, list):
                x = torch.tensor(x, device=device, dtype=dtype).to(
                    memory_format=memory_format
                )
                expected = torch.tensor(expected, device=device, dtype=dtype).to(
                    memory_format=memory_format
                )
            self.assertEqual(model(x), expected)

        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        check(
            [[[[-1, -2], [-3, -4]]]],
            (2, 2, 1, 2, False, True),
            [[[[-4, -4], [-4, -4]]]],
            torch.contiguous_format,
        )
        check(
            [[[[-1, -2], [-3, -4]]]],
            (2, 2, 1, 2, False, True),
            [[[[-4, -4], [-4, -4]]]],
            torch.channels_last,
        )
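
    # The expected [[-4, -4], [-4, -4]] follows from the same dilated-window
    # arithmetic in 2D: the 2x2 input is -inf padded to 4x4 and each window
    # samples rows/cols {i, i + 2}; for every window position the only finite
    # cell sampled is the one holding -4, so all four outputs are -4.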

    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @gcIfJetson
    def test_max_pool3d_ndhwc(self, device, dtype):
        def helper(n, c, h, w, d, kernel_size, stride=None):
            batch = n
            if not batch:
                batch = 1
            input = torch.randn(batch, c, d, h, w, dtype=dtype, device=device)
            input = input.contiguous(
                memory_format=torch.channels_last_3d
            ).requires_grad_()
            if not n:
                input = input.squeeze(0).detach().clone().requires_grad_()
            if isinstance(kernel_size, int):
                kernel_size = [kernel_size] * 3
            if stride is None:
                stride = kernel_size
            elif isinstance(stride, int):
                stride = [stride] * 3
            grad = torch.randn(
                batch,
                c,
                (d - kernel_size[0]) // stride[0] + 1,
                (h - kernel_size[1]) // stride[1] + 1,
                (w - kernel_size[2]) // stride[2] + 1,
                dtype=dtype,
                device=device,
            )
            grad = grad.contiguous(memory_format=torch.channels_last_3d)
            if not n:
                grad = grad.squeeze(0)
            pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(
                device
            )

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(
                device
            )
            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            if len(out.shape) == 4:
                self.assertTrue(
                    out.unsqueeze(0).is_contiguous(memory_format=torch.channels_last_3d)
                )
            else:
                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(ref_out.is_contiguous())
            if len(ind.shape) == 4:
                self.assertTrue(
                    ind.unsqueeze(0).is_contiguous(memory_format=torch.channels_last_3d)
                )
            else:
                self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last_3d))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            if dtype == torch.half:
                self.assertEqual(input.grad, ref_input.grad, atol=0.05, rtol=0.01)
            else:
                self.assertEqual(input.grad, ref_input.grad)

        helper(4, 8, 8, 8, 8, 7)
        helper(4, 8, 8, 8, 8, (5, 6, 7))
        helper(1, 8, 8, 8, 8, (5, 6, 7))
        helper(0, 6, 12, 13, 14, (5, 6, 7))
        helper(4, 8, 7, 7, 7, 3, stride=1)
        helper(10, 128, 19, 19, 19, 3, stride=2)
        helper(10, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(1, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(0, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(1, 79, 4, 4, 4, 3, stride=2)
        helper(0, 79, 4, 4, 4, 3, stride=2)
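
    # Max pooling selects elements exactly, but at half precision the scattered
    # gradient accumulation can differ between the channels_last_3d kernel and
    # the contiguous reference, which is presumably why torch.half gets the
    # looser atol/rtol above.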

    @onlyCPU
    @dtypes(torch.half, torch.bfloat16)
    def test_max_pool_bfloat16_half(self, device, dtype):
        def helper(shape, kernel_size, stride, memory_format, dtype):
            input = torch.randn(shape, dtype=dtype, device=device)
            input = input.to(memory_format=memory_format).requires_grad_()
            if len(shape) == 4:
                pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(
                    device
                )
            else:
                pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(
                    device
                )

            input2 = input.detach().clone().float().requires_grad_(True)

            out, ind = pool(input)
            out.sum().backward()
            out2, ind2 = pool(input2)
            out2.sum().backward()

            self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2.to(dtype=dtype))
            self.assertEqual(ind, ind2)
            self.assertEqual(input.grad, input2.grad.to(dtype=dtype))

        helper((4, 30, 8, 8), 7, 1, torch.contiguous_format, dtype)
        helper((4, 65, 8, 8), 7, 1, torch.channels_last, dtype)
        helper((1, 19, 20, 10), 8, 2, torch.contiguous_format, dtype)
        helper((1, 19, 20, 10), 8, 2, torch.channels_last, dtype)
        helper((4, 30, 8, 8), 7, 1, torch.contiguous_format, dtype)
        helper((4, 65, 8, 8), 7, 1, torch.channels_last, dtype)
        helper((1, 19, 10, 10, 10), 8, 2, torch.contiguous_format, dtype)
        helper((1, 19, 10, 9, 14), 8, 2, torch.channels_last_3d, dtype)
        helper((4, 10, 3, 8, 8), 3, 1, torch.contiguous_format, dtype)
        helper((4, 10, 8, 8, 8), 7, 1, torch.channels_last_3d, dtype)

    @onlyCUDA
    @gcIfJetson
    def test_max_pool2d_indices(self, device):
        def helper(n, c, h, w, ks):
            if n is None:
                x = torch.randn(
                    c, h, w, device="cuda", dtype=torch.float, requires_grad=True
                )
            else:
                x = torch.randn(
                    n, c, h, w, device="cuda", dtype=torch.float, requires_grad=True
                )

            ref_x = x.detach().clone().cpu().requires_grad_()

            pool = torch.nn.MaxPool2d(kernel_size=ks, return_indices=True)

            y, idx = pool(x)
            ref_y, ref_idx = pool(ref_x)

            y.sum().backward()
            ref_y.sum().backward()

            self.assertEqual(y, ref_y)
            self.assertEqual(
                idx, ref_idx
            )  # assertEqual implicitly compares shape for tensors
            self.assertEqual(x.grad, ref_x.grad)

        helper(2, 8, 4, 4, ks=2)
        helper(None, 3, 50, 50, ks=5)

    @onlyCPU
    @dtypes(torch.half, torch.bfloat16)
    def test_avg_pool2d_reduced_floating(self, device, dtype):
        def helper(n, c, h, w, kernel_size, stride, memory_format):
            input = torch.randn(n, c, h, w, dtype=torch.float32, device=device).to(
                dtype=dtype
            )
            input = input.to(memory_format=memory_format).requires_grad_()
            pool = torch.nn.AvgPool2d(kernel_size, stride).to(device)

            input2 = input.detach().clone().float().requires_grad_(True)

            out = pool(input)
            out.sum().backward()
            out2 = pool(input2)
            out2.sum().backward()

            self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertEqual(out.dtype, dtype)
            self.assertEqual(input.grad.dtype, dtype)
            self.assertEqual(out, out2.to(dtype=dtype))
            self.assertEqual(input.grad, input2.grad.to(dtype=dtype))

        helper(4, 30, 8, 8, 7, 1, torch.contiguous_format)
        helper(4, 65, 8, 8, 7, 1, torch.channels_last)
        helper(1, 19, 20, 10, 8, 2, torch.contiguous_format)
        helper(1, 19, 20, 10, 8, 2, torch.channels_last)
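
    # Like test_max_pool_bfloat16_half above, this uses the standard
    # reduced-precision recipe: run the op natively in half/bfloat16, run it
    # again in float32 on a cloned input, and require the low-precision output
    # and gradient to match the float32 results cast back down.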

    @dtypes(torch.float, torch.double)
    def test_adaptive_pooling_max_nhwc(self, device, dtype):
        def helper(input_size, output_plane_size, contig):
            n_plane_dims = len(output_plane_size)
            mod = (
                torch.nn.AdaptiveMaxPool2d
                if n_plane_dims == 2
                else torch.nn.AdaptiveMaxPool3d
            )
            channels_last = (
                torch.channels_last if n_plane_dims == 2 else torch.channels_last_3d
            )
            output_size = input_size[:2] + output_plane_size
            input = torch.randint(1, 10, input_size, device=device, dtype=dtype)
            input = input.contiguous(memory_format=channels_last)
            grad = torch.randint(1, 10, output_size, device=device, dtype=dtype)
            grad = grad.contiguous(memory_format=channels_last)
            if not contig:
                input = input[:, ::2]
                grad = grad[:, ::2]
            input.requires_grad_(True)
            pool = mod(output_plane_size, return_indices=True).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = mod(output_plane_size, return_indices=True).to(device)

            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            # channels_last_3d case does not return channels_last_3d outputs
            if n_plane_dims == 2:
                self.assertTrue(out.is_contiguous(memory_format=channels_last))
                self.assertTrue(ind.is_contiguous(memory_format=channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            self.assertEqual(input.grad, ref_input.grad)

        for contig in [True, False]:
            helper((4, 8, 10, 10), (7, 7), contig)
            helper((4, 8, 9, 14), (5, 8), contig)
            helper((4, 8, 11, 11), (1, 1), contig)
            helper((2, 1, 3, 3), (1, 1), contig)
            helper((4, 8, 10, 10, 10), (7, 7, 7), contig)
            helper((4, 8, 11, 11, 11), (1, 1, 1), contig)
            helper((2, 1, 3, 3, 3), (1, 1, 1), contig)

    @dtypes(torch.float, torch.double)
    def test_pooling_max_nhwc(self, device, dtype):
        def helper(n, c, h, w, kernel_size, stride, padding, dilation, contig, device):
            output_height = math.floor(
                (h + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1)
                / stride[0]
                + 1
            )
            output_width = math.floor(
                (w + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1)
                / stride[1]
                + 1
            )

            input = torch.randint(1, 10, (n, c, h, w), device=device, dtype=dtype)
            input = input.contiguous(memory_format=torch.channels_last)
            grad = torch.randint(
                1, 10, (n, c, output_height, output_width), device=device, dtype=dtype
            )
            grad = grad.contiguous(memory_format=torch.channels_last)
            if not contig:
                input = input[:, ::2, :, :]
                grad = grad[:, ::2, :, :]
            input.requires_grad_(True)
            pool = torch.nn.MaxPool2d(
                kernel_size,
                stride,
                padding,
                dilation,
                return_indices=True,
                ceil_mode=False,
            )

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.MaxPool2d(
                kernel_size,
                stride,
                padding,
                dilation,
                return_indices=True,
                ceil_mode=False,
            ).to(device)

            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            self.assertEqual(input.grad, ref_input.grad)

        for contig in [True, False]:
            helper(4, 8, 10, 10, (2, 2), (1, 1), (1, 1), (2, 2), contig, device)
            helper(4, 8, 9, 14, (2, 2), (1, 1), (1, 1), (2, 2), contig, device)
            helper(4, 8, 11, 11, (4, 4), (2, 2), (2, 2), (2, 2), contig, device)

    @onlyCUDA
    def test_pool3d_size_one_feature_dim(self, device):
        # Tests crazy strides for feature dim of size 1
        x = torch.randn(7, 1, 5, 3, 2, device=device)
        strange_strides = [30, 1234, 6, 2, 1]
        y = x.as_strided(x.size(), strange_strides)
        x = x.cpu().as_strided(x.size(), strange_strides)

        to_test = {
            "max_pool3d": lambda t: F.max_pool3d(t, (5, 1, 1), stride=(5, 1, 1)),
            "avg_pool3d": lambda t: F.avg_pool3d(t, (5, 1, 1), stride=(5, 1, 1)),
        }

        for test, fn in to_test.items():
            # Should not crash
            out_y = fn(y)
            out_x = fn(x)
            self.assertEqual(out_y, out_x.to(device), msg=test)

    @onlyCUDA
    @largeTensorTest("18GB")
    @largeTensorTest("180GB", "cpu")
    def test_pool3d_large_size_int64(self, device):
        # See https://github.com/pytorch/pytorch/issues/52822
        x = torch.randn(
            70, 32, 100, 100, 100, dtype=torch.half, device=device, requires_grad=True
        )
        y = torch.nn.functional.max_pool3d(x, 5)
        g = torch.randn_like(y, dtype=torch.half)
        torch.cuda.synchronize()
        y.backward(g)
        torch.cuda.synchronize()

        ref_x = x.detach().cpu().float()  # max_pool3d_cpu is not implemented for half
        ref_x.requires_grad = True
        ref_g = g.cpu().float()
        ref_y = torch.nn.functional.max_pool3d(ref_x, 5)
        ref_y.backward(ref_g)

        self.assertEqual(y, ref_y, exact_dtype=False)
        self.assertEqual(x.grad, ref_x.grad, exact_dtype=False)
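
    # 70 * 32 * 100**3 = 2.24e9 elements does not fit in a 32-bit index, so the
    # test above exercises the 64-bit indexing path from issue 52822; the
    # float32 CPU rerun exists only because max_pool3d has no half CPU kernel.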

    @onlyCUDA
    def test_AvgPool3d_backward_after_cat_dim1_device(self, device):
        # x has to have batch_size 1 to test contiguous checks
        x = torch.randn(1, 3, 4, 4, 4, device=device, requires_grad=True)
        y = F.avg_pool3d(x, kernel_size=3, padding=1, stride=2)

        grad = torch.randn(y.size(), device=device)
        # Increase the stride in dimension 0; the tensor is still contiguous
        # because size[0] is 1.
        stride = list(grad.stride())
        stride[0] = stride[0] * 2
        grad.set_(grad.storage(), 0, grad.size(), stride)
        assert grad.is_contiguous()

        y.backward(grad)
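
    # The set_() trick above works because contiguity ignores the stride of any
    # size-1 dimension: with size[0] == 1, doubling stride[0] changes only the
    # metadata, so the tensor still is (and reports) contiguous, matching the
    # kind of gradient layout a cat() along dim 1 can produce.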

    def _test_maxpool_indices(
        self, num_dim, adaptive=False, device="cpu", dtype=torch.float
    ):
        def expected_indices(dim, dtype):
            if dim == 1:
                return torch.tensor([1, 3], dtype=dtype).repeat(2, 2, 1)
            if dim == 2:
                return torch.tensor([[5, 7], [13, 15]], dtype=dtype).repeat(2, 2, 1, 1)

        def expected_grad(dim, dtype):
            if dim == 1:
                return torch.tensor([0, 1, 0, 1], dtype=dtype).repeat(2, 2, 1)
            grad = expected_grad(dim - 1, dtype=dtype)
            zero = torch.zeros(grad.size(), dtype=dtype)
            return torch.stack((zero, grad, zero, grad), 2)

        def expected_output(dim, dtype):
            if dim == 1:
                return torch.arange(2, 17, 2, dtype=dtype).view(2, 2, 2)
            if dim == 2:
                col = torch.arange(6, 63, 8, dtype=dtype)
                return torch.stack([col, col + 2], 1).view(2, 2, 2, 2)

        if adaptive:
            cls_name = "AdaptiveMaxPool{}d".format(num_dim)  # noqa: UP032
        else:
            # FIXME(#105716): Test fails when using f-string
            cls_name = "MaxPool{}d".format(num_dim)  # noqa: UP032
        module_cls = getattr(nn, cls_name)
        module = module_cls(2, return_indices=True).to(device, dtype=dtype)
        numel = 4 ** (num_dim + 1)
        input = (
            torch.arange(1, numel + 1)
            .view(2, 2, *repeat(4, num_dim))
            .to(device, dtype=dtype)
        )
        input_var = input.clone().detach().requires_grad_()

        # Check forward
        output, indices = module(input_var)
        if num_dim != 3:
            expected_indices = expected_indices(num_dim, dtype=indices.data.dtype)
            expected_output = expected_output(num_dim, dtype=output.data.dtype)
            self.assertEqual(indices.dim(), input.dim())
            self.assertEqual(indices.data.squeeze(), expected_indices)
            self.assertEqual(output.data.squeeze(), expected_output)
        self.assertTrue(output.requires_grad)
        self.assertFalse(indices.requires_grad)

        # Make sure backward works
        grad_output = torch.ones(output.size(), device=device, dtype=dtype)
        output.backward(grad_output, retain_graph=True)
        expected_grad = expected_grad(num_dim, dtype=input_var.grad.data.dtype)
        self.assertEqual(input_var.grad.data, expected_grad.view_as(input))

        # Make sure backward after changing indices will result in an error
        indices.add_(1)
        self.assertRaises(RuntimeError, lambda: output.backward(grad_output))

        # Make sure -Infinity is handled correctly
        t = torch.tensor([[[float("-inf")]]])
        m = nn.MaxPool1d(kernel_size=1, return_indices=True)
        output, indices = m(t)
        self.assertEqual(output[0, 0, 0], float("-inf"))
        self.assertEqual(indices[0, 0, 0], 0)

        t = torch.tensor([[[float("-inf")]]])
        m = nn.MaxPool2d(kernel_size=1, return_indices=True)
        output, indices = m(t)
        self.assertEqual(output[0, 0, 0], float("-inf"))
        self.assertEqual(indices[0, 0, 0], 0)

        t = torch.tensor([[[[float("-inf")]]]])
        m = nn.MaxPool3d(kernel_size=1, return_indices=True)
        output, indices = m(t)
        self.assertEqual(output[0, 0, 0, 0], float("-inf"))
        self.assertEqual(indices[0, 0, 0, 0], 0)
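
    # The RuntimeError expected after indices.add_(1) comes from autograd's
    # version counters: indices is saved for the backward pass, and mutating it
    # in place bumps its version, so replaying the retained graph fails.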

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool1d_indices(self, device, dtype):
        self._test_maxpool_indices(1, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool2d_indices(self, device, dtype):
        self._test_maxpool_indices(2, device=device, dtype=dtype)

    @skipIfMps
    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool3d_indices(self, device, dtype):
        self._test_maxpool_indices(3, device=device, dtype=dtype)

    @skipIfMps
    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_AdaptiveMaxPool1d_indices(self, device, dtype):
        self._test_maxpool_indices(1, adaptive=True, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_AdaptiveMaxPool2d_indices(self, device, dtype):
        self._test_maxpool_indices(2, adaptive=True, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_AdaptiveMaxPool3d_indices(self, device, dtype):
        self._test_maxpool_indices(3, adaptive=True, device=device, dtype=dtype)
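
    # Sketch for orientation (an assumption about the native kernels, shown
    # only as exposition): adaptive pooling derives each window from the input
    # and output lengths instead of a fixed kernel, roughly
    #
    #   def adaptive_window(i, in_len, out_len):
    #       start = (i * in_len) // out_len
    #       end = -((-(i + 1) * in_len) // out_len)  # ceil division
    #       return start, end
    #
    # so with in_len=4 and out_len=2 the windows are [0, 2) and [2, 4), which
    # is why _test_maxpool_indices can reuse one set of expected values for
    # both the plain and the adaptive modules.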

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_maxpool_indices_no_batch_dim(self, device, dtype):
        """Check that indices with no batch dim are consistent with a single batch."""
        max_pool_cases = [
            (
                nn.MaxPool1d(3, return_indices=True),
                torch.randn(3, 5, device=device, dtype=dtype),
            ),
            (
                nn.MaxPool2d(3, return_indices=True),
                torch.randn(3, 5, 6, device=device, dtype=dtype),
            ),
            (
                nn.MaxPool3d(3, return_indices=True),
                torch.randn(3, 5, 6, 7, device=device, dtype=dtype),
            ),
            (
                nn.AdaptiveMaxPool1d(3, return_indices=True),
                torch.randn(3, 5, device=device, dtype=dtype),
            ),
            (
                nn.AdaptiveMaxPool2d(3, return_indices=True),
                torch.randn(3, 5, 6, device=device, dtype=dtype),
            ),
            (
                nn.AdaptiveMaxPool3d(3, return_indices=True),
                torch.randn(3, 5, 6, 7, device=device, dtype=dtype),
            ),
        ]

        for module, input in max_pool_cases:
            _, indices_no_batch = module(input)
            _, indices_single_batch = module(input.unsqueeze(0))
            self.assertEqual(indices_no_batch, indices_single_batch.squeeze(0))
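
    # A short gloss with a sketch (added, not from the original file): max
    # pooling is expected to keep NaN rather than silently drop it, which is
    # what the next test pins down; the minimal repro looks like
    #
    #   t = torch.full((1, 1, 3), float("nan"))
    #   F.max_pool1d(t, 3)  # -> tensor([[[nan]]])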

    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float)
    @onlyNativeDeviceTypes  # TODO: Fails on XLA
    @gcIfJetson
    def test_max_pool_nan_inf(self, device, dtype):
        for adaptive in ["", "adaptive_"]:
            for num_dim in [1, 2, 3]:
                fn_name = f"{adaptive}max_pool{num_dim}d"
                fn = getattr(F, fn_name)

                x = torch.full(
                    [1, 1] + num_dim * [3],
                    nan,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                res = fn(x, 1 if adaptive else 3)
                res.backward(torch.randn_like(res))
                self.assertTrue(math.isnan(res.item()))
                x.requires_grad_(False)
                res = fn(x, 1 if adaptive else 3)
                self.assertTrue(math.isnan(res.item()))

                x2 = torch.full(
                    [1, 1] + num_dim * [3],
                    -inf,
                    device=device,
                    dtype=dtype,
                    requires_grad=True,
                )
                res2 = fn(x2, 1 if adaptive else 3)
                res2.backward(torch.randn_like(res2))
                self.assertTrue(math.isinf(res2.item()))
                x2.requires_grad_(False)
                res2 = fn(x2, 1 if adaptive else 3)
                self.assertTrue(math.isinf(res2.item()))
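
    # Sketch for orientation (my summary, not from this file): fractional max
    # pooling follows Graham, "Fractional Max-Pooling" (arXiv:1412.6071). For
    # the 2d case, _random_samples has shape (N, C, 2) with values in [0, 1)
    # and pins down the otherwise random choice of pooling regions, which is
    # what makes the gradcheck calls below deterministic:
    #
    #   x = torch.randn(1, 2, 7, 7, dtype=torch.double)
    #   s = torch.rand(1, 2, 2, dtype=torch.double)
    #   y = F.fractional_max_pool2d(x, (2, 2), output_size=(3, 3), _random_samples=s)
    #   # y.shape -> (1, 2, 3, 3)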

    @expectedFailureMeta  # RuntimeError: Unrecognized tensor type ID: Meta
    @onlyNativeDeviceTypes
    def test_fractional_max_pool2d(self, device):
        with set_default_dtype(torch.double):
            x = torch.randn(1, 2, 7, 7, requires_grad=True, device=device)
            samples = x.new(1, 2, 2).uniform_()

            def func(x):
                return F.fractional_max_pool2d(
                    x, (2, 2), output_size=(3, 3), _random_samples=samples
                )

            self.assertEqual(func(x).shape, (1, 2, 3, 3))
            gradcheck(func, [x])
            gradgradcheck(func, [x])

            x = torch.randn(2, 7, 7, requires_grad=True, device=device)
            self.assertEqual(func(x).shape, (2, 3, 3))
            if self.device_type != "cuda":
                # Reference: https://github.com/pytorch/pytorch/issues/52427
                # Raises -> RuntimeError: TensorAccessor expected 4 dims but tensor has 3
                # on CUDA in gradcheck
                gradcheck(func, [x])
                gradgradcheck(func, [x])

            for kernel_size in [(), (1,)]:
                with self.assertRaisesRegex(RuntimeError, "kernel_size must either"):
                    # Incorrect kernel_size
                    F.fractional_max_pool2d(
                        x,
                        kernel_size=kernel_size,
                        output_size=(3, 3),
                        _random_samples=samples,
                    )

            err_large_msg = "too large relative to input "
            err_out_size_msg = "output_size must either"
            for output_size, msg in [
                ((9, 3), err_large_msg + "height"),
                ((3, 9), err_large_msg + "width"),
                ((3,), err_out_size_msg),
                ((), err_out_size_msg),
            ]:
                with self.assertRaisesRegex(RuntimeError, msg):
                    # Incorrect output_size
                    F.fractional_max_pool2d(
                        x, (2, 2), output_size=output_size, _random_samples=samples
                    )
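
    # Aside with a sketch (added gloss): gradcheck compares autograd's
    # analytic Jacobian against a finite-difference estimate, which is why
    # these tests run under torch.double -- in float32 the numerical Jacobian
    # would be too noisy for the default tolerances. Schematically, for the
    # 3d case below (where the samples are (N, C, 3)):
    #
    #   x = torch.randn(1, 2, 7, 7, 7, dtype=torch.double, requires_grad=True)
    #   s = torch.rand(1, 2, 3, dtype=torch.double)
    #   gradcheck(
    #       lambda t: F.fractional_max_pool3d(
    #           t, (2, 2, 2), output_size=(3, 3, 3), _random_samples=s
    #       ),
    #       [x],
    #   )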

    @expectedFailureMeta  # RuntimeError: Unrecognized tensor type ID: Meta
    @onlyNativeDeviceTypes
    def test_fractional_max_pool3d(self, device):
        with set_default_dtype(torch.double):
            x = torch.randn(1, 2, 7, 7, 7, requires_grad=True, device=device)
            samples = x.new(1, 2, 3).uniform_()

            def func(x):
                return F.fractional_max_pool3d(
                    x, (2, 2, 2), output_size=(3, 3, 3), _random_samples=samples
                )

            self.assertEqual(func(x).shape, (1, 2, 3, 3, 3))
            gradcheck(func, [x])
            gradgradcheck(func, [x])

            x = torch.randn(2, 7, 7, 7, requires_grad=True, device=device)
            self.assertEqual(func(x).shape, (2, 3, 3, 3))
            gradcheck(func, [x])
            gradgradcheck(func, [x])

            for kernel_size in [(), (1,), (1, 1)]:
                with self.assertRaisesRegex(RuntimeError, "kernel_size must either"):
                    # Incorrect kernel_size
                    F.fractional_max_pool3d(
                        x,
                        kernel_size=kernel_size,
                        output_size=(3, 3, 3),
                        _random_samples=samples,
                    )

            err_large_msg = "too large relative to input "
            err_out_size_msg = "output_size must either"
            for output_size, msg in [
                ((9, 3, 3), err_large_msg + "time"),
                ((3, 9, 3), err_large_msg + "height"),
                ((3, 3, 9), err_large_msg + "width"),
                ((3, 3), err_out_size_msg),
                ((3,), err_out_size_msg),
                ((), err_out_size_msg),
            ]:
                with self.assertRaisesRegex(RuntimeError, msg):
                    # Incorrect output_size
                    F.fractional_max_pool3d(
                        x, (2, 2, 2), output_size=output_size, _random_samples=samples
                    )

    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float)
    @onlyNativeDeviceTypes  # TODO: Fails on XLA
    def test_fractional_max_pool_nan_inf(self, device, dtype):
        for num_dim in [2, 3]:
            fn_name = f"FractionalMaxPool{num_dim}d"
            fn = getattr(nn, fn_name)(kernel_size=2, output_size=1)
            x = torch.full(
                [1, 1] + num_dim * [3],
                nan,
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            res = fn(x)
            res.backward(torch.randn_like(res))
            self.assertTrue(math.isnan(res.item()))

            x2 = torch.full(
                [1, 1] + num_dim * [3],
                -inf,
                device=device,
                dtype=dtype,
                requires_grad=True,
            )
            res2 = fn(x2)
            res2.backward(torch.randn_like(res2))
            self.assertTrue(math.isinf(res2.item()))

    @onlyNativeDeviceTypes  # TODO: RuntimeError message different on XLA
    def test_pooling_zero_stride(self, device):
        for op in ("max", "avg"):
            for num_dim in [1, 2, 3]:
                fn_name = f"{op}_pool{num_dim}d"
                fn = getattr(F, fn_name)
                x = torch.ones(
                    [1, 2] + num_dim * [4], device=device, dtype=torch.float
                )
                self.assertRaisesRegex(
                    RuntimeError,
                    r"stride should not be zero|stride must be greater than zero",
                    lambda: fn(x, kernel_size=2, stride=0),
                )

                fn_module_name = f"{op.title()}Pool{num_dim}d"
                fn_module = getattr(nn, fn_module_name)(kernel_size=2, stride=0)
                self.assertRaisesRegex(
                    RuntimeError,
                    r"stride should not be zero|stride must be greater than zero",
                    lambda: fn_module(x),
                )
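
    # Sketch of the shape rule the stride/size checks exercise (restating the
    # documented pooling formula for context; the helper name is illustrative):
    #
    #   def pool_out_len(l_in, kernel, stride, padding=0, dilation=1):
    #       return (l_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    #
    # stride=0 would divide by zero, hence the RuntimeError above; and a
    # configuration that drives this to zero triggers the "too small" errors
    # checked in test_pool_invalid_size further down.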

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_pool_large_size(self, device, dtype):
        for op in ("max", "avg"):
            for num_dim in [1, 2, 3]:
                fn_name = f"{op}_pool{num_dim}d"
                fn = getattr(F, fn_name)
                # 16777217 is the smallest integer not expressible in float32
                x = torch.ones(
                    [1, 1, 16777217] + (num_dim - 1) * [1], device=device, dtype=dtype
                )
                res = fn(x, 1, stride=1, padding=0)
                # check if the output shape was still computed correctly
                self.assertEqual(x.shape[2], res.shape[2])

    @onlyCUDA
    @largeTensorTest("6GB")
    def test_pooling_large(self, device):
        def helper(pool):
            inp = torch.randn(
                2**7 + 10, 2**8, 2**8, 2**8, dtype=torch.half, device="cuda"
            )
            self.assertTrue(inp.numel() > 2**31 - 1)
            out = pool(inp)
            torch.cuda.synchronize()  # asserts test finishes normally without raising errors

        helper(nn.MaxPool2d(4, 4))
        helper(nn.AvgPool2d(4, 4))
        helper(nn.FractionalMaxPool2d(4, 4))
        helper(nn.AdaptiveMaxPool2d((2**6, 2**6)))
        helper(nn.AdaptiveAvgPool2d((2**6, 2**6)))
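
    # Context note with a sketch (added gloss, my reading of the test's
    # intent): the input above deliberately crosses 2**31 - 1 elements so the
    # CUDA kernels are forced into 64-bit index arithmetic; a 32-bit index
    # would wrap and address the wrong memory.
    #
    #   n = (2**7 + 10) * 2**8 * 2**8 * 2**8
    #   assert n == 2_315_255_808 and n > 2**31 - 1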

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_pool_invalid_size(self, device, dtype):
        for op in ("max", "avg"):
            for num_dim in [1, 2, 3]:
                fn_name = f"{op}_pool{num_dim}d"
                if op == "max":
                    # New implementation without indices supports empty tensors
                    # TODO(Heitor) change once with_indices code is updated
                    fn_name += "_with_indices"
                fn = getattr(F, fn_name)
                # use a configuration that gives zero outputs only
                # when doing a correct floor division by the stride
                x = torch.ones([1, 1] + num_dim * [4], device=device, dtype=dtype)
                with self.assertRaisesRegex(RuntimeError, r"too small|smaller than"):
                    try:
                        res = fn(x, 3, stride=2, padding=0, dilation=2)
                    except TypeError:
                        # some implementations do not support dilation
                        res = fn(x, 6, stride=2, padding=0)

    @onlyCUDA
    def test_pooling_bfloat16(self, device):
        _test_bfloat16_ops(
            self,
            torch.nn.AvgPool1d(3, stride=2),
            device,
            inp_dims=(8, 4, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AvgPool2d(3, stride=2),
            device,
            inp_dims=(8, 4, 16, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AvgPool3d(3, stride=2),
            device,
            inp_dims=(8, 4, 16, 16, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self, torch.nn.AdaptiveAvgPool1d(3), device, inp_dims=(8, 4, 16), prec=0.05
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AdaptiveAvgPool2d((3, 5)),
            device,
            inp_dims=(8, 4, 16, 16),
            prec=0.05,
        )
        _test_bfloat16_ops(
            self,
            torch.nn.AdaptiveAvgPool3d((3, 5, 7)),
            device,
            inp_dims=(8, 4, 16, 16, 16),
            prec=0.05,
        )
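
    # Aside with a sketch (added gloss): bfloat16 keeps float32's 8 exponent
    # bits but only 7 mantissa bits, so roughly 2-3 significant decimal digits
    # survive an average -- hence the loose prec=0.05 above.
    #
    #   torch.tensor(1 / 3).bfloat16().float()  # -> ~0.33398, not 0.33333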

    def test_maxpool3d_non_square_backward(self, device):
        # A previous CUDA routine for this backward computed the kernel launch
        # grid size with the last two dimensions interchanged, so the tail
        # along the longer dim was ignored. Here we test whether every
        # position gets a gradient.
        for dim in (2, 3, 4):
            shape = tuple(32 if i != dim else 256 for i in range(4))
            x = torch.randn(shape, device=device, requires_grad=True)
            F.max_pool3d(x, kernel_size=(1, 1, 1)).sum().backward()
            self.assertEqual(x.grad, torch.ones_like(x.grad))

    @slowTest
    def test_adaptive_pool_odd_size(self, device):
        # See https://github.com/pytorch/pytorch/issues/81409
        Ih, Iw, Oh, Ow = 5873, 3693, 3527, 2219
        imgs = torch.randint(low=0, high=256, size=(11, Ih, Iw), dtype=torch.float)
        imgs_ = F.adaptive_avg_pool2d(imgs, (Oh, Ow))
        imgs_ = F.adaptive_max_pool2d(imgs, (Oh, Ow))

        Id, Ih, Iw, Od, Oh, Ow = 3, 5873, 3693, 3, 3527, 2219
        imgs = torch.randint(low=0, high=256, size=(3, Id, Ih, Iw), dtype=torch.float)
        imgs_ = F.adaptive_avg_pool3d(imgs, (Od, Oh, Ow))
        imgs_ = F.adaptive_max_pool3d(imgs, (Od, Oh, Ow))


instantiate_device_type_tests(TestPoolingNNDeviceType, globals())
instantiate_parametrized_tests(TestPoolingNN)

if __name__ == "__main__":
    run_tests()