1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: tests"] 2*da0073e9SAndroid Build Coastguard Workerimport random 3*da0073e9SAndroid Build Coastguard Workerimport unittest 4*da0073e9SAndroid Build Coastguard Workerfrom functools import partial 5*da0073e9SAndroid Build Coastguard Workerfrom itertools import combinations, permutations, product 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport numpy as np 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor 11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 12*da0073e9SAndroid Build Coastguard Worker dtypes, 13*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 14*da0073e9SAndroid Build Coastguard Worker onlyCPU, 15*da0073e9SAndroid Build Coastguard Worker onlyNativeDeviceTypes, 16*da0073e9SAndroid Build Coastguard Worker skipLazy, 17*da0073e9SAndroid Build Coastguard Worker skipMeta, 18*da0073e9SAndroid Build Coastguard Worker skipXLA, 19*da0073e9SAndroid Build Coastguard Worker) 20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import ( 21*da0073e9SAndroid Build Coastguard Worker all_types_and, 22*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and, 23*da0073e9SAndroid Build Coastguard Worker complex_types, 24*da0073e9SAndroid Build Coastguard Worker floating_and_complex_types_and, 25*da0073e9SAndroid Build Coastguard Worker) 26*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 27*da0073e9SAndroid Build Coastguard Worker gradcheck, 28*da0073e9SAndroid Build Coastguard Worker gradgradcheck, 29*da0073e9SAndroid Build Coastguard Worker IS_FBCODE, 30*da0073e9SAndroid Build Coastguard Worker numpy_to_torch_dtype_dict, 31*da0073e9SAndroid Build Coastguard Worker run_tests, 
32*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 33*da0073e9SAndroid Build Coastguard Worker suppress_warnings, 34*da0073e9SAndroid Build Coastguard Worker TestCase, 35*da0073e9SAndroid Build Coastguard Worker) 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker# TODO: replace this with make_tensor() in common_utils.py 39*da0073e9SAndroid Build Coastguard Workerdef _generate_input(shape, dtype, device, with_extremal): 40*da0073e9SAndroid Build Coastguard Worker if shape == (): 41*da0073e9SAndroid Build Coastguard Worker x = torch.tensor((), dtype=dtype, device=device) 42*da0073e9SAndroid Build Coastguard Worker else: 43*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point or dtype.is_complex: 44*da0073e9SAndroid Build Coastguard Worker # work around torch.randn not being implemented for bfloat16 45*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 46*da0073e9SAndroid Build Coastguard Worker x = torch.randn(*shape, device=device) * random.randint(30, 100) 47*da0073e9SAndroid Build Coastguard Worker x = x.to(torch.bfloat16) 48*da0073e9SAndroid Build Coastguard Worker else: 49*da0073e9SAndroid Build Coastguard Worker x = torch.randn(*shape, dtype=dtype, device=device) * random.randint( 50*da0073e9SAndroid Build Coastguard Worker 30, 100 51*da0073e9SAndroid Build Coastguard Worker ) 52*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = 0 53*da0073e9SAndroid Build Coastguard Worker if with_extremal and dtype.is_floating_point: 54*da0073e9SAndroid Build Coastguard Worker # Use extremal values 55*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = float("nan") 56*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = float("inf") 57*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = float("-inf") 58*da0073e9SAndroid Build Coastguard Worker elif with_extremal and 
dtype.is_complex: 59*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = complex("nan") 60*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = complex("inf") 61*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = complex("-inf") 62*da0073e9SAndroid Build Coastguard Worker elif dtype == torch.bool: 63*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(shape, dtype=dtype, device=device) 64*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = True 65*da0073e9SAndroid Build Coastguard Worker else: 66*da0073e9SAndroid Build Coastguard Worker x = torch.randint(15, 100, shape, dtype=dtype, device=device) 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker return x 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker# TODO: replace this with make_tensor() in common_utils.py 72*da0073e9SAndroid Build Coastguard Workerdef _rand_shape(dim, min_size, max_size): 73*da0073e9SAndroid Build Coastguard Worker shape = [] 74*da0073e9SAndroid Build Coastguard Worker for i in range(dim): 75*da0073e9SAndroid Build Coastguard Worker shape.append(random.randint(min_size, max_size)) 76*da0073e9SAndroid Build Coastguard Worker return tuple(shape) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker# TODO: refactor tests to avoid this function 80*da0073e9SAndroid Build Coastguard Worker# Converts half/bfloat16 dtype to float when device is cpu 81*da0073e9SAndroid Build Coastguard Workerdef _convert_t(dtype, device): 82*da0073e9SAndroid Build Coastguard Worker if device == "cpu" and dtype in {torch.half, torch.bfloat16}: 83*da0073e9SAndroid Build Coastguard Worker return torch.float 84*da0073e9SAndroid Build Coastguard Worker return dtype 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard 

# TODO: replace this with make_tensor() in common_utils.py
# Returns a tensor of the requested shape, dtype, and device
# Requesting a half CPU tensor returns a float CPU tensor with
# values representable by a half.
# Initialization uses randint for non-float types and randn for float types.
def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor:
    """Return a test tensor of the requested ``shape``, ``dtype`` and ``device``.

    Args:
        shape: sequence of ints describing the result's size.
        dtype: requested dtype. When ``device == "cpu"``, ``torch.half`` and
            ``torch.bfloat16`` come back as ``torch.float`` (via
            ``_convert_t``); random floats are first rounded through the
            requested low-precision type so they are representable in it.
        device: target device. The CPU special cases compare it against the
            string ``"cpu"``.
        fill_ones: if True, return a tensor of ones instead of random values.

    Non-floating, non-complex dtypes (including bool) are filled from
    ``randint(0, 10)``, shifted down by 5 for every dtype except ``uint8``,
    then cast. Floating and complex dtypes are filled from ``randn``.
    """
    # Returns a tensor filled with ones
    if fill_ones:
        return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device)

    # Returns a tensor with random integer values
    if not (dtype.is_floating_point or dtype.is_complex):
        t = torch.randint(0, 10, shape, device=device)
        if dtype != torch.uint8:
            t = t - 5  # generate negative values also
        return t.to(_convert_t(dtype, device))

    # Populates the CPU tensor with floats representable as half/bfloat16
    if dtype == torch.half and device == "cpu":
        return torch.randn(*shape, dtype=torch.float, device=device).half().float()
    if dtype == torch.bfloat16 and device == "cpu":
        return torch.randn(*shape, dtype=torch.float, device=device).bfloat16().float()

    # Default: randn already produces ``dtype``, so no extra cast is needed
    return torch.randn(shape, dtype=dtype, device=device)


# Tests ops and indexing to ensure they return views (and new tensors) as
# appropriate.
class TestViewOps(TestCase):
    # NOTE(review): presumably makes assertEqual also compare dtypes; the
    # behavior is defined by TestCase, which is not shown here.
    exact_dtype = True

    def is_view_of(self, base, other):
        """Return True if ``other`` is a view whose ``_base`` is ``base``.

        ``other`` must report ``_is_view()``, be a different object from
        ``base``, have ``other._base is base``, and live on the same device.
        On cpu/cuda the two tensors must also share the same untyped storage.
        """
        if (
            not other._is_view()
            or other is base
            or other._base is not base
            or base.device != other.device
        ):
            return False
        # Note: only validates storage on native device types
        # because some accelerators, like XLA, do not expose storage
        if base.device.type == "cpu" or base.device.type == "cuda":
            if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
                return False

        return True

    # Returns true if v1 and v2 are views of the same base
    def is_view_of_same_base(self, v1, v2):
        if not v1._is_view() or v1 is v2:
            return False
        return self.is_view_of(v1._base, v2)

    # Returns the input tensor as is if contiguous=True; otherwise returns
    # x.transpose(dim0, dim1), a view that is generally non-contiguous
    def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
        if contiguous:
            return x
        else:
            return x.transpose(dim0, dim1)

    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_conj_self(self, device, dtype):
        # For every non-complex dtype under test, conj() must hand back the
        # very same tensor object (no view, no copy).
        t = torch.ones(5, 5, dtype=dtype, device=device)
        s = t.conj()
        self.assertTrue(s is t)

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
    def test_view_dtype_new(self, device, dtype):
        # torch dtype -> NumPy dtype; bool is not used as a view target.
        # (Named so it does not shadow the module-level ``dtypes`` decorator.)
        torch_to_np_dtypes = {
            value: key for (key, value) in numpy_to_torch_dtype_dict.items()
        }
        del torch_to_np_dtypes[torch.bool]

        def generate_inputs():
            # Contiguous, permuted, expanded, sliced, empty and 0-dim inputs
            yield make_tensor((4, 4, 64), dtype=dtype, device=device, low=-5, high=5)
            yield make_tensor(
                (4, 4, 64), dtype=dtype, device=device, low=-5, high=5
            ).permute(1, 0, 2)
            yield make_tensor(
                (4, 64, 4), dtype=dtype, device=device, low=-5, high=5
            ).permute(2, 0, 1)
            yield make_tensor(
                (1, 5, 1), dtype=dtype, device=device, low=-5, high=5
            ).expand(5, 5, 64)
            yield make_tensor((2, 5, 256), dtype=dtype, device=device, low=-5, high=5)[
                1::2, 1:, ::2
            ]
            yield make_tensor((0, 5, 64), dtype=dtype, device=device, low=-5, high=5)
            yield make_tensor((), dtype=dtype, device=device, low=-5, high=5)

        def calc_expected_size_and_stride(a, view_dtype):
            # Expected size/stride of a.view(view_dtype): the last dim is
            # scaled by the element-size ratio and every stride is rescaled,
            # with the last stride pinned to 1.
            dtype_size = torch._utils._element_size(a.dtype)
            view_dtype_size = torch._utils._element_size(view_dtype)

            if dtype_size == view_dtype_size:
                return a.size(), a.stride()

            elif dtype_size > view_dtype_size:
                size_ratio = dtype_size // view_dtype_size

                view_size = list(a.size())
                view_size[-1] = view_size[-1] * size_ratio

                view_stride = [stride * size_ratio for stride in a.stride()]
                view_stride[-1] = 1
                return torch.Size(view_size), tuple(view_stride)

            else:
                size_ratio = view_dtype_size // dtype_size

                view_size = list(a.size())
                view_size[-1] = view_size[-1] // size_ratio

                view_stride = [stride // size_ratio for stride in a.stride()]
                view_stride[-1] = 1
                return torch.Size(view_size), tuple(view_stride)

        for a in generate_inputs():
            a_np = a.cpu().numpy()
            a_np_contiguous = a.cpu().contiguous().numpy()

            for view_dtype, np_view_dtype in torch_to_np_dtypes.items():
                equal_element_size = torch._utils._element_size(
                    dtype
                ) == torch._utils._element_size(view_dtype)

                if not equal_element_size and a.dim() == 0:
                    with self.assertRaisesRegex(
                        RuntimeError, r"self.dim\(\) cannot be 0"
                    ):
                        a.view(view_dtype)
                    continue

                if not equal_element_size and a.stride(-1) != 1:
                    with self.assertRaisesRegex(
                        RuntimeError, r"self.stride\(-1\) must be 1"
                    ):
                        a.view(view_dtype)
                    continue

                a_view = a.view(view_dtype)
                self.assertEqual(a_view.dtype, view_dtype)
                self.assertEqual(a.data_ptr(), a_view.data_ptr())

                expected_size, expected_stride = calc_expected_size_and_stride(
                    a, view_dtype
                )
                self.assertEqual(a_view.size(), expected_size)
                self.assertEqual(a_view.stride(), expected_stride)

                # Round-tripping back to the original dtype must be bit-exact
                self.assertEqual(a_view.view(dtype), a, rtol=0, atol=0)

                # NumPy's dtype view requires contiguous input if target
                # dtype is a different size
                if equal_element_size:
                    a_np_view = a_np.view(np_view_dtype)

                else:
                    a_np_view = a_np_contiguous.view(np_view_dtype)

                self.assertEqual(a_view, a_np_view)

        # Test that requires_grad is dropped for floating point casts,
        # because view(dtype) does not support backward yet
        # TODO: Remove this when autograd support is added
        if dtype.is_floating_point or dtype.is_complex:
            for view_dtype in floating_and_complex_types_and(
                torch.half, torch.bfloat16
            ):
                t = make_tensor(
                    (5, 5, 64),
                    dtype=dtype,
                    device=device,
                    low=-5,
                    high=5,
                    requires_grad=True,
                )
                self.assertFalse(t.view(view_dtype).requires_grad)

    # Test the extra error checks that happen when the view dtype
    # has a greater element size than the original dtype
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_view_dtype_upsize_errors(self, device, dtype):
        dtype_size = torch._utils._element_size(dtype)

        for view_dtype in all_types_and_complex_and(
            torch.half, torch.bfloat16, torch.bool
        ):
            view_dtype_size = torch._utils._element_size(view_dtype)
            if view_dtype_size <= dtype_size:
                continue

            size_ratio = view_dtype_size // dtype_size
            # Last dim of size_ratio + 1 is not divisible by size_ratio
            a = make_tensor(
                (4, 4, size_ratio + 1), dtype=dtype, device=device, low=-5, high=5
            )
            with self.assertRaisesRegex(
                RuntimeError, rf"self.size\(-1\) must be divisible by {size_ratio}"
            ):
                a.view(view_dtype)

            # Slicing off the first element gives a storage offset of 1
            with self.assertRaisesRegex(
                RuntimeError,
                rf"self.storage_offset\(\) must be divisible by {size_ratio}",
            ):
                a[:, :, 1:].view(view_dtype)

            # stride(1) == 1 is not divisible by size_ratio
            a = make_tensor(
                (4, 4, size_ratio), dtype=dtype, device=device, low=-5, high=5
            )
            a = a.as_strided((4, 4, size_ratio), (size_ratio, 1, 1))
            with self.assertRaisesRegex(
                RuntimeError, rf"self.stride\(1\) must be divisible by {size_ratio}"
            ):
                a.view(view_dtype)

    @onlyNativeDeviceTypes
    def test_view_as_complex(self, device):
        def fn(contiguous_input=True, dim0=0, dim1=1):
            t = torch.randn(3, 2, 2, device=device)
            c_t = t[:, :, 0] + 1j * t[:, :, 1]

            input = self._do_transpose(t, contiguous_input, dim0, dim1)

            if input.size()[-1] != 2:
                self.assertRaisesRegex(
                    RuntimeError,
                    "Tensor must have a last dimension of size 2",
                    lambda: torch.view_as_complex(input),
                )
                return

            if input.stride()[-1] != 1:
                self.assertRaisesRegex(
                    RuntimeError,
                    "Tensor must have a last dimension with stride 1",
                    lambda: torch.view_as_complex(input),
                )
                return

            res = torch.view_as_complex(input)
            self.assertEqual(res, self._do_transpose(c_t, contiguous_input, dim0, dim1))
            self.assertTrue(self.is_view_of(t, res))

        fn()
        fn(contiguous_input=False)
        # RuntimeError since in this case the last dim of input would not be of size 2
        fn(contiguous_input=False, dim0=0, dim1=2)
        # RuntimeError since in this case the last dim of input would not have stride 1
        fn(contiguous_input=False, dim0=1, dim1=2)

        # RuntimeError since in this case the stride of a non-last dim of input
        # would not be divisible by 2
        x = torch.randn(3, 3, device=device)
        t = torch.as_strided(x, (2, 2), (1, 1))
        self.assertRaisesRegex(
            RuntimeError,
            "Tensor must have a stride divisible by 2 for all but last dimension",
            lambda: torch.view_as_complex(t),
        )

        # tensor with zero elements
        x = torch.tensor([], device=device)  # torch.Size([0])
        self.assertRaisesRegex(
            RuntimeError,
            "Tensor must have a last dimension of size 2",
            lambda: torch.view_as_complex(x),
        )

        # zero dimension tensor
        z = torch.tensor(2.0, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            "Input tensor must have one or more dimensions",
            lambda: torch.view_as_complex(z),
        )

        y = x.reshape(0, 2)  # torch.Size([0, 2])
        res = torch.view_as_complex(y)
        self.assertTrue(self.is_view_of(x, res))
        self.assertEqual(res.shape, torch.Size([0]))

    @onlyNativeDeviceTypes
    @dtypes(*complex_types(), torch.complex32)
    def test_view_as_real(self, device, dtype):
        def fn(contiguous_input=True):
            t = torch.randn(3, 4, dtype=dtype, device=device)
            input = self._do_transpose(t, contiguous_input)
            res = torch.view_as_real(input)
            self.assertEqual(res[:, :, 0], input.real)
            self.assertEqual(res[:, :, 1], input.imag)
            self.assertTrue(self.is_view_of(t, res))

        fn()
        fn(contiguous_input=False)

        # tensor with zero elements
        x = torch.tensor([], dtype=dtype, device=device)
        res = torch.view_as_real(x)
        self.assertTrue(self.is_view_of(x, res))
        self.assertEqual(res.shape, torch.Size([0, 2]))

        # tensor with zero dim
        x = torch.tensor(2 + 3j, dtype=dtype, device=device)
        res = torch.view_as_real(x)
        self.assertTrue(self.is_view_of(x, res))
        self.assertEqual(res.shape, torch.Size([2]))

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_view_tensor_split(self, device, dtype):
        a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
        a_split_dim0 = a.tensor_split(7, 0)
        for a_split_dim0_tensor in a_split_dim0:
            self.assertTrue(self.is_view_of(a, a_split_dim0_tensor))
        a_split_dim1 = a.tensor_split(7, 1)
        for a_split_dim1_tensor in a_split_dim1:
            self.assertTrue(self.is_view_of(a, a_split_dim1_tensor))

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_view_tensor_hsplit(self, device, dtype):
        t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
        t_hsplit = torch.hsplit(t, 2)
        for t_hsplit_tensor in t_hsplit:
            self.assertTrue(self.is_view_of(t, t_hsplit_tensor))
        # A write through the base must be visible through the split view
        t[2, 2, 2] = 7
        self.assertEqual(t_hsplit[1][2, 0, 2], t[2, 2, 2])

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_view_tensor_vsplit(self, device, dtype):
        t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
        t_vsplit = torch.vsplit(t, 2)
        for t_vsplit_tensor in t_vsplit:
            self.assertTrue(self.is_view_of(t, t_vsplit_tensor))
        # A write through the base must be visible through the split view
        t[2, 2, 2] = 7
        self.assertEqual(t_vsplit[1][0, 2, 2], t[2, 2, 2])

427*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 428*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 429*da0073e9SAndroid Build Coastguard Worker def test_view_tensor_dsplit(self, device, dtype): 430*da0073e9SAndroid Build Coastguard Worker t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) 431*da0073e9SAndroid Build Coastguard Worker t_dsplit = torch.dsplit(t, 2) 432*da0073e9SAndroid Build Coastguard Worker for t_dsplit_tensor in t_dsplit: 433*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, t_dsplit_tensor)) 434*da0073e9SAndroid Build Coastguard Worker t[2, 2, 2] = 7 435*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2]) 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 438*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and(torch.half, torch.bfloat16)) 439*da0073e9SAndroid Build Coastguard Worker def test_imag_noncomplex(self, device, dtype): 440*da0073e9SAndroid Build Coastguard Worker t = torch.ones((5, 5), dtype=dtype, device=device) 441*da0073e9SAndroid Build Coastguard Worker 442*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 443*da0073e9SAndroid Build Coastguard Worker torch.imag(t) 444*da0073e9SAndroid Build Coastguard Worker 445*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 446*da0073e9SAndroid Build Coastguard Worker @dtypes(*complex_types()) 447*da0073e9SAndroid Build Coastguard Worker def test_real_imag_view(self, device, dtype): 448*da0073e9SAndroid Build Coastguard Worker def compare_with_numpy(contiguous_input=True): 449*da0073e9SAndroid Build Coastguard Worker t = torch.randn(3, 3, dtype=dtype, device=device) 450*da0073e9SAndroid Build Coastguard Worker if not contiguous_input: 451*da0073e9SAndroid Build Coastguard Worker u = t.T 452*da0073e9SAndroid 
Build Coastguard Worker else: 453*da0073e9SAndroid Build Coastguard Worker u = t 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Worker re = u.real 456*da0073e9SAndroid Build Coastguard Worker exp = torch.from_numpy(u.cpu().numpy().real).to(device=device) 457*da0073e9SAndroid Build Coastguard Worker self.assertEqual(re, exp) 458*da0073e9SAndroid Build Coastguard Worker # for the case of contiguous_input, t=u 459*da0073e9SAndroid Build Coastguard Worker # for the case of non contiguous_input, the base still remains 460*da0073e9SAndroid Build Coastguard Worker # t since we are performing a view operation to make the input non-contiguous 461*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, re)) 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Worker im = u.imag 464*da0073e9SAndroid Build Coastguard Worker exp = torch.from_numpy(u.cpu().numpy().imag).to(device=device) 465*da0073e9SAndroid Build Coastguard Worker self.assertEqual(im, exp) 466*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, im)) 467*da0073e9SAndroid Build Coastguard Worker 468*da0073e9SAndroid Build Coastguard Worker compare_with_numpy() 469*da0073e9SAndroid Build Coastguard Worker compare_with_numpy(contiguous_input=False) 470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker # ensure storage offset is being correctly set 472*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10, dtype=dtype) 473*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[5:].real, a.real[5:]) 474*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[5:].imag, a.imag[5:]) 475*da0073e9SAndroid Build Coastguard Worker 476*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 477*da0073e9SAndroid Build Coastguard Worker @dtypes(*complex_types()) 478*da0073e9SAndroid Build Coastguard Worker def test_conj_imag_view(self, device, dtype) -> None: 
479*da0073e9SAndroid Build Coastguard Worker t = _make_tensor((4, 5), dtype, device) 480*da0073e9SAndroid Build Coastguard Worker t_numpy_conj = torch.from_numpy(t.cpu().numpy().conj()).to(device=device) 481*da0073e9SAndroid Build Coastguard Worker v = t.conj() 482*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 483*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v, t_numpy_conj) 484*da0073e9SAndroid Build Coastguard Worker 485*da0073e9SAndroid Build Coastguard Worker if t.is_complex(): 486*da0073e9SAndroid Build Coastguard Worker v_imag = v.imag 487*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v_imag)) 488*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v_imag, t_numpy_conj.imag) 489*da0073e9SAndroid Build Coastguard Worker self.assertTrue(v_imag.is_neg()) 490*da0073e9SAndroid Build Coastguard Worker 491*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 492*da0073e9SAndroid Build Coastguard Worker def test_conj_view_with_shared_memory(self, device) -> None: 493*da0073e9SAndroid Build Coastguard Worker a = _make_tensor((4, 5), torch.cfloat, device) 494*da0073e9SAndroid Build Coastguard Worker b = a.conj() 495*da0073e9SAndroid Build Coastguard Worker c = a.conj() 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.add(a, b), a.add_(b)) 498*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.add(b, c), torch.add(b, c, out=a)) 499*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.add(b, c), b.add_(c)) 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 502*da0073e9SAndroid Build Coastguard Worker @dtypes( 503*da0073e9SAndroid Build Coastguard Worker *product( 504*da0073e9SAndroid Build Coastguard Worker complex_types(), 505*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and(torch.half, torch.bfloat16, 
torch.bool), 506*da0073e9SAndroid Build Coastguard Worker ) 507*da0073e9SAndroid Build Coastguard Worker ) 508*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 509*da0073e9SAndroid Build Coastguard Worker def test_set_real_imag(self, device, dtypes): 510*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, dtype=dtypes[0], device=device) 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker new_real = _make_tensor((10,), dtypes[1], device) 513*da0073e9SAndroid Build Coastguard Worker new_imag = _make_tensor((10,), dtypes[1], device) 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker x.real = new_real 516*da0073e9SAndroid Build Coastguard Worker x.imag = new_imag 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker if dtypes[1].is_complex: 519*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.real, new_real.real, exact_dtype=False) 520*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.imag, new_imag.real, exact_dtype=False) 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker else: 523*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.real, new_real, exact_dtype=False) 524*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.imag, new_imag, exact_dtype=False) 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker def test_diagonal_view(self, device) -> None: 527*da0073e9SAndroid Build Coastguard Worker t = torch.ones((5, 5), device=device) 528*da0073e9SAndroid Build Coastguard Worker v = torch.diagonal(t) 529*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 530*da0073e9SAndroid Build Coastguard Worker 531*da0073e9SAndroid Build Coastguard Worker v[0] = 0 532*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[0, 0], v[0]) 533*da0073e9SAndroid Build Coastguard Worker 534*da0073e9SAndroid Build 
Coastguard Worker t = torch.ones((3, 3, 3), device=device) 535*da0073e9SAndroid Build Coastguard Worker v = torch.diagonal(t, offset=1, dim1=1, dim2=2) 536*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker v[0, 0] = 0 539*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[0, 0, 1], v[0, 0]) 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker def test_select_view(self, device) -> None: 542*da0073e9SAndroid Build Coastguard Worker t = torch.ones((5, 5), device=device) 543*da0073e9SAndroid Build Coastguard Worker v = t.select(0, 2) 544*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker v[0] = 0 547*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[2, 0], v[0]) 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Worker # Lazy hasn't implemented unbind yet. 
550*da0073e9SAndroid Build Coastguard Worker @skipLazy 551*da0073e9SAndroid Build Coastguard Worker def test_unbind_view(self, device) -> None: 552*da0073e9SAndroid Build Coastguard Worker t = torch.zeros((5, 5), device=device) 553*da0073e9SAndroid Build Coastguard Worker tup = torch.unbind(t) 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker for idx, v in enumerate(tup): 556*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 557*da0073e9SAndroid Build Coastguard Worker 558*da0073e9SAndroid Build Coastguard Worker v[0] = idx + 1 559*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[idx, 0], v[0]) 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Worker # TODO: opinfo this or move to unbind's test suite 562*da0073e9SAndroid Build Coastguard Worker def test_unbind(self): 563*da0073e9SAndroid Build Coastguard Worker stacked = torch.randn(3, 10, 10, requires_grad=True) 564*da0073e9SAndroid Build Coastguard Worker x, y, z = stacked.unbind() 565*da0073e9SAndroid Build Coastguard Worker grad = torch.randn(3, 10, 10) 566*da0073e9SAndroid Build Coastguard Worker torch.autograd.backward([x, y, z], grad.unbind()) 567*da0073e9SAndroid Build Coastguard Worker self.assertEqual(stacked.grad, grad) 568*da0073e9SAndroid Build Coastguard Worker # check that it works with only one gradient provided (#9977) 569*da0073e9SAndroid Build Coastguard Worker for i in range(3): 570*da0073e9SAndroid Build Coastguard Worker stacked = torch.randn(3, 10, 10, requires_grad=True) 571*da0073e9SAndroid Build Coastguard Worker outs = stacked.unbind() 572*da0073e9SAndroid Build Coastguard Worker gi = grad.unbind()[i] 573*da0073e9SAndroid Build Coastguard Worker (g,) = torch.autograd.grad(outs[i], stacked, gi) 574*da0073e9SAndroid Build Coastguard Worker g_expected = torch.stack( 575*da0073e9SAndroid Build Coastguard Worker [gi if j == i else torch.zeros_like(gi) for j in range(3)], dim=0 
576*da0073e9SAndroid Build Coastguard Worker ) 577*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g, g_expected) 578*da0073e9SAndroid Build Coastguard Worker # Check with gradcheck 579*da0073e9SAndroid Build Coastguard Worker stacked = torch.randn(3, 10, 10, dtype=torch.double, requires_grad=True) 580*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: x.unbind(), (stacked,), check_forward_ad=True) 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker # TODO: Fix this test for LTC. There is an interaction with dynamic shapes here that is broken, 583*da0073e9SAndroid Build Coastguard Worker # causing asserts to trigger. 584*da0073e9SAndroid Build Coastguard Worker @skipLazy 585*da0073e9SAndroid Build Coastguard Worker def test_expand_view(self, device) -> None: 586*da0073e9SAndroid Build Coastguard Worker t = torch.ones((5, 1), device=device) 587*da0073e9SAndroid Build Coastguard Worker v = t.expand(5, 5) 588*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 589*da0073e9SAndroid Build Coastguard Worker 590*da0073e9SAndroid Build Coastguard Worker v[2, 2] = 0 591*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[2, 0], v[2, 2]) 592*da0073e9SAndroid Build Coastguard Worker 593*da0073e9SAndroid Build Coastguard Worker def test_expand_as_view(self, device): 594*da0073e9SAndroid Build Coastguard Worker t = torch.ones((5, 1), device=device) 595*da0073e9SAndroid Build Coastguard Worker e = torch.empty((5, 5), device=device) 596*da0073e9SAndroid Build Coastguard Worker v = t.expand_as(e) 597*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker v[2, 2] = 0 600*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[2, 0], v[2, 2]) 601*da0073e9SAndroid Build Coastguard Worker 602*da0073e9SAndroid Build Coastguard Worker def test_narrow_view(self, device): 
603*da0073e9SAndroid Build Coastguard Worker t = torch.ones((5, 5), device=device) 604*da0073e9SAndroid Build Coastguard Worker v = torch.narrow(t, 1, 2, 2) 605*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker v[0, 0] = 0 608*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[0, 2], v[0, 0]) 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker def test_permute_view(self, device) -> None: 611*da0073e9SAndroid Build Coastguard Worker t = torch.ones((5, 5), device=device) 612*da0073e9SAndroid Build Coastguard Worker v = t.permute(1, 0) 613*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 614*da0073e9SAndroid Build Coastguard Worker 615*da0073e9SAndroid Build Coastguard Worker v[0, 1] = 0 616*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 0], v[0, 1]) 617*da0073e9SAndroid Build Coastguard Worker 618*da0073e9SAndroid Build Coastguard Worker def test_transpose_view(self, device): 619*da0073e9SAndroid Build Coastguard Worker for fn in (torch.swapdims, torch.swapaxes, torch.transpose): 620*da0073e9SAndroid Build Coastguard Worker t = torch.ones((5, 5), device=device) 621*da0073e9SAndroid Build Coastguard Worker v = fn(t, 0, 1) 622*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 623*da0073e9SAndroid Build Coastguard Worker 624*da0073e9SAndroid Build Coastguard Worker v[0, 1] = 0 625*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 0], v[0, 1]) 626*da0073e9SAndroid Build Coastguard Worker 627*da0073e9SAndroid Build Coastguard Worker def test_transpose_inplace_view(self, device): 628*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 629*da0073e9SAndroid Build Coastguard Worker v = t.view_as(t) 630*da0073e9SAndroid Build Coastguard Worker v = v.swapdims_(0, 1) 631*da0073e9SAndroid Build 
Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 632*da0073e9SAndroid Build Coastguard Worker v[0, 1] = 0 633*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 0], v[0, 1]) 634*da0073e9SAndroid Build Coastguard Worker 635*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 636*da0073e9SAndroid Build Coastguard Worker v = t.view_as(t) 637*da0073e9SAndroid Build Coastguard Worker v = v.swapaxes_(0, 1) 638*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 639*da0073e9SAndroid Build Coastguard Worker v[0, 1] = 0 640*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 0], v[0, 1]) 641*da0073e9SAndroid Build Coastguard Worker 642*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 643*da0073e9SAndroid Build Coastguard Worker v = t.view_as(t) 644*da0073e9SAndroid Build Coastguard Worker v = v.transpose_(0, 1) 645*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 646*da0073e9SAndroid Build Coastguard Worker v[0, 1] = 0 647*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 0], v[0, 1]) 648*da0073e9SAndroid Build Coastguard Worker 649*da0073e9SAndroid Build Coastguard Worker def test_t_view(self, device): 650*da0073e9SAndroid Build Coastguard Worker t = torch.ones((5, 5), device=device) 651*da0073e9SAndroid Build Coastguard Worker v = t.t() 652*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 653*da0073e9SAndroid Build Coastguard Worker 654*da0073e9SAndroid Build Coastguard Worker v[0, 1] = 0 655*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 0], v[0, 1]) 656*da0073e9SAndroid Build Coastguard Worker 657*da0073e9SAndroid Build Coastguard Worker def test_t_inplace_view(self, device): 658*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 659*da0073e9SAndroid Build Coastguard Worker v = t.view_as(t) 660*da0073e9SAndroid Build Coastguard Worker v = v.t_() 
661*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 662*da0073e9SAndroid Build Coastguard Worker v[0, 1] = 0 663*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 0], v[0, 1]) 664*da0073e9SAndroid Build Coastguard Worker 665*da0073e9SAndroid Build Coastguard Worker def test_T_view(self, device): 666*da0073e9SAndroid Build Coastguard Worker for op in ("T", "H", "mT", "mH"): 667*da0073e9SAndroid Build Coastguard Worker t = torch.ones((5, 5), device=device) 668*da0073e9SAndroid Build Coastguard Worker v = getattr(t, op) 669*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 670*da0073e9SAndroid Build Coastguard Worker 671*da0073e9SAndroid Build Coastguard Worker v[0, 1] = 0 672*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 0], v[0, 1]) 673*da0073e9SAndroid Build Coastguard Worker 674*da0073e9SAndroid Build Coastguard Worker def test_unfold_view(self, device): 675*da0073e9SAndroid Build Coastguard Worker t = torch.ones(10, device=device) 676*da0073e9SAndroid Build Coastguard Worker v = t.unfold(0, 3, 2) 677*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 678*da0073e9SAndroid Build Coastguard Worker 679*da0073e9SAndroid Build Coastguard Worker v[1, 0] = 0 680*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[2], v[1, 0]) 681*da0073e9SAndroid Build Coastguard Worker 682*da0073e9SAndroid Build Coastguard Worker def test_squeeze_view(self, device): 683*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 1, 5, device=device) 684*da0073e9SAndroid Build Coastguard Worker v = torch.squeeze(t) 685*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 686*da0073e9SAndroid Build Coastguard Worker v[0, 1] = 0 687*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, v._base) 688*da0073e9SAndroid Build Coastguard Worker 689*da0073e9SAndroid Build Coastguard Worker def test_squeeze_inplace_view(self, device): 
690*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 691*da0073e9SAndroid Build Coastguard Worker v = t.view_as(t) 692*da0073e9SAndroid Build Coastguard Worker v = v.squeeze_() 693*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 694*da0073e9SAndroid Build Coastguard Worker v[0, 1] = 0 695*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, v._base) 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker def test_unsqueeze_view(self, device): 698*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 699*da0073e9SAndroid Build Coastguard Worker v = torch.unsqueeze(t, 1) 700*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 701*da0073e9SAndroid Build Coastguard Worker 702*da0073e9SAndroid Build Coastguard Worker v[0, 0, 1] = 0 703*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[0, 1], v[0, 0, 1]) 704*da0073e9SAndroid Build Coastguard Worker 705*da0073e9SAndroid Build Coastguard Worker def test_unsqueeze_inplace_view(self, device): 706*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 707*da0073e9SAndroid Build Coastguard Worker v = t.view_as(t) 708*da0073e9SAndroid Build Coastguard Worker v = v.unsqueeze_(1) 709*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 710*da0073e9SAndroid Build Coastguard Worker v[0, 0, 1] = 0 711*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[0, 1], v[0, 0, 1]) 712*da0073e9SAndroid Build Coastguard Worker 713*da0073e9SAndroid Build Coastguard Worker def test_as_strided_view(self, device): 714*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 715*da0073e9SAndroid Build Coastguard Worker v = torch.as_strided(t, (25,), (1,)) 716*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 717*da0073e9SAndroid Build Coastguard Worker 718*da0073e9SAndroid Build 
Coastguard Worker v[6] = 0 719*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 1], v[6]) 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Worker def test_as_strided_inplace_view(self, device): 722*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 723*da0073e9SAndroid Build Coastguard Worker v = t.view_as(t) 724*da0073e9SAndroid Build Coastguard Worker v = v.as_strided_((25,), (1,)) 725*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 726*da0073e9SAndroid Build Coastguard Worker v[6] = 0 727*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 1], v[6]) 728*da0073e9SAndroid Build Coastguard Worker 729*da0073e9SAndroid Build Coastguard Worker def test_as_strided_gradients(self): 730*da0073e9SAndroid Build Coastguard Worker def test(x, prepro_fn, size, strides, offset=None): 731*da0073e9SAndroid Build Coastguard Worker x = x.to(torch.double).detach().requires_grad_() 732*da0073e9SAndroid Build Coastguard Worker 733*da0073e9SAndroid Build Coastguard Worker # Check that forward will **not** resize storage because it may 734*da0073e9SAndroid Build Coastguard Worker # cause NaN in output and fail numerical Jacobian check consequently 735*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 736*da0073e9SAndroid Build Coastguard Worker y = prepro_fn(x) if prepro_fn is not None else x 737*da0073e9SAndroid Build Coastguard Worker max_offset = sum((si - 1) * st for si, st in zip(size, strides)) 738*da0073e9SAndroid Build Coastguard Worker max_offset += offset if offset is not None else y.storage_offset() 739*da0073e9SAndroid Build Coastguard Worker assert max_offset < len(y.storage()), "test case resizes storage" 740*da0073e9SAndroid Build Coastguard Worker 741*da0073e9SAndroid Build Coastguard Worker def closure(x): 742*da0073e9SAndroid Build Coastguard Worker if prepro_fn is not None: 743*da0073e9SAndroid Build Coastguard Worker x = prepro_fn(x) 
744*da0073e9SAndroid Build Coastguard Worker return x.as_strided(size, strides, offset) 745*da0073e9SAndroid Build Coastguard Worker 746*da0073e9SAndroid Build Coastguard Worker gradcheck(closure, [x], check_forward_ad=True) 747*da0073e9SAndroid Build Coastguard Worker gradgradcheck(closure, [x]) 748*da0073e9SAndroid Build Coastguard Worker 749*da0073e9SAndroid Build Coastguard Worker # test 750*da0073e9SAndroid Build Coastguard Worker test(torch.arange(0, 25), lambda x: x.view(5, 5), [3, 3], [6, 2], 2) 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Worker # test crazy stride at dim with size 1 case 753*da0073e9SAndroid Build Coastguard Worker test(torch.randn(12), None, [1, 2, 1, 5], [0, 5, 100, 1], 2) 754*da0073e9SAndroid Build Coastguard Worker 755*da0073e9SAndroid Build Coastguard Worker # test expand case 756*da0073e9SAndroid Build Coastguard Worker test(torch.randn(5), None, [3, 3, 3], [0, 1, 0], 2) 757*da0073e9SAndroid Build Coastguard Worker test(torch.randn(5), None, [3, 3, 3], [0, 0, 0], 4) 758*da0073e9SAndroid Build Coastguard Worker test(torch.randn(5), lambda x: x.expand(5, 5), [5, 5], [0, 1], 0) 759*da0073e9SAndroid Build Coastguard Worker 760*da0073e9SAndroid Build Coastguard Worker # test non-expand overlapping case 761*da0073e9SAndroid Build Coastguard Worker test(torch.randn(35), None, [6, 6], [5, 1], 2) 762*da0073e9SAndroid Build Coastguard Worker test(torch.randn(15), None, [3, 2], [3, 6], 2) 763*da0073e9SAndroid Build Coastguard Worker 764*da0073e9SAndroid Build Coastguard Worker # test transpose case 765*da0073e9SAndroid Build Coastguard Worker test(torch.randn(3, 4), None, [4, 3], [1, 4]) 766*da0073e9SAndroid Build Coastguard Worker 767*da0073e9SAndroid Build Coastguard Worker # test "getting things outside the input" case 768*da0073e9SAndroid Build Coastguard Worker x = torch.randn(6, 2) 769*da0073e9SAndroid Build Coastguard Worker test(x[3:], None, [3, 2], [2, 1], 0) # should be all zeros 
770*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[3:].as_strided([3, 2], [2, 1], 0), x[:3]) 771*da0073e9SAndroid Build Coastguard Worker 772*da0073e9SAndroid Build Coastguard Worker # test select on expanded input case 773*da0073e9SAndroid Build Coastguard Worker test(torch.randn(2, 3), lambda x: x.expand(10, 2, 3), [2, 3], [3, 1], 0) 774*da0073e9SAndroid Build Coastguard Worker 775*da0073e9SAndroid Build Coastguard Worker def test_view_view(self, device): 776*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 777*da0073e9SAndroid Build Coastguard Worker v = t.view(25) 778*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 779*da0073e9SAndroid Build Coastguard Worker 780*da0073e9SAndroid Build Coastguard Worker v[6] = 0 781*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 1], v[6]) 782*da0073e9SAndroid Build Coastguard Worker 783*da0073e9SAndroid Build Coastguard Worker def test_view_as_view(self, device): 784*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 785*da0073e9SAndroid Build Coastguard Worker e = torch.empty((25,)) 786*da0073e9SAndroid Build Coastguard Worker v = t.view_as(e) 787*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 788*da0073e9SAndroid Build Coastguard Worker 789*da0073e9SAndroid Build Coastguard Worker v[6] = 0 790*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 1], v[6]) 791*da0073e9SAndroid Build Coastguard Worker 792*da0073e9SAndroid Build Coastguard Worker def test_contiguous_self(self, device): 793*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 794*da0073e9SAndroid Build Coastguard Worker s = t.contiguous() 795*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s is t) 796*da0073e9SAndroid Build Coastguard Worker 797*da0073e9SAndroid Build Coastguard Worker @skipMeta 798*da0073e9SAndroid Build Coastguard Worker # self.is_view_of reports false 
positives for lazy 799*da0073e9SAndroid Build Coastguard Worker @skipLazy 800*da0073e9SAndroid Build Coastguard Worker def test_contiguous_nonview(self, device): 801*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 802*da0073e9SAndroid Build Coastguard Worker nv = t.t().contiguous() 803*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not self.is_view_of(t, nv)) 804*da0073e9SAndroid Build Coastguard Worker 805*da0073e9SAndroid Build Coastguard Worker nv[0, 0] = 0 806*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(t[0, 0], nv[0, 0]) 807*da0073e9SAndroid Build Coastguard Worker 808*da0073e9SAndroid Build Coastguard Worker def test_reshape_view(self, device): 809*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 810*da0073e9SAndroid Build Coastguard Worker v = torch.reshape(t, (25,)) 811*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 812*da0073e9SAndroid Build Coastguard Worker 813*da0073e9SAndroid Build Coastguard Worker v[6] = 0 814*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 1], v[6]) 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build Coastguard Worker def test_reshape_as_view(self, device): 817*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 818*da0073e9SAndroid Build Coastguard Worker e = torch.empty((25,), device=device) 819*da0073e9SAndroid Build Coastguard Worker v = t.reshape_as(e) 820*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 821*da0073e9SAndroid Build Coastguard Worker 822*da0073e9SAndroid Build Coastguard Worker v[6] = 0 823*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[1, 1], v[6]) 824*da0073e9SAndroid Build Coastguard Worker 825*da0073e9SAndroid Build Coastguard Worker @skipMeta 826*da0073e9SAndroid Build Coastguard Worker # self.is_view_of reports false positives for lazy 827*da0073e9SAndroid Build Coastguard Worker @skipLazy 
828*da0073e9SAndroid Build Coastguard Worker def test_reshape_nonview(self, device): 829*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 830*da0073e9SAndroid Build Coastguard Worker nv = torch.reshape(t.t(), (25,)) 831*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not self.is_view_of(t, nv)) 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker nv[6] = 0 834*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(t[1, 1], nv[6]) 835*da0073e9SAndroid Build Coastguard Worker 836*da0073e9SAndroid Build Coastguard Worker # This test use as_strided to construct a tensor with overlapping memory, 837*da0073e9SAndroid Build Coastguard Worker # which is not handled by the functionalization pass. 838*da0073e9SAndroid Build Coastguard Worker @skipLazy 839*da0073e9SAndroid Build Coastguard Worker @skipXLA 840*da0073e9SAndroid Build Coastguard Worker def test_flatten_view(self, device): 841*da0073e9SAndroid Build Coastguard Worker def test_writes_propagate(t, v): 842*da0073e9SAndroid Build Coastguard Worker idx_t = (0,) * t.ndim 843*da0073e9SAndroid Build Coastguard Worker idx_v = (0,) * v.ndim 844*da0073e9SAndroid Build Coastguard Worker v[idx_v] = 0 845*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[idx_t], v[idx_v]) 846*da0073e9SAndroid Build Coastguard Worker 847*da0073e9SAndroid Build Coastguard Worker t = torch.ones(1, 2, 3, 4, device=device) 848*da0073e9SAndroid Build Coastguard Worker v = t.flatten() 849*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 850*da0073e9SAndroid Build Coastguard Worker test_writes_propagate(t, v) 851*da0073e9SAndroid Build Coastguard Worker 852*da0073e9SAndroid Build Coastguard Worker # zero-dimensional tensor 853*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(1, device=device) 854*da0073e9SAndroid Build Coastguard Worker v = t.flatten() 855*da0073e9SAndroid Build Coastguard Worker test_writes_propagate(t, v) 
856*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 857*da0073e9SAndroid Build Coastguard Worker 858*da0073e9SAndroid Build Coastguard Worker t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3) 859*da0073e9SAndroid Build Coastguard Worker v = t.flatten(0, 1) 860*da0073e9SAndroid Build Coastguard Worker test_writes_propagate(t, v) 861*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of_same_base(t, v)) 862*da0073e9SAndroid Build Coastguard Worker 863*da0073e9SAndroid Build Coastguard Worker # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups: 864*da0073e9SAndroid Build Coastguard Worker t = torch.ones(720, device=device).as_strided( 865*da0073e9SAndroid Build Coastguard Worker (2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0) 866*da0073e9SAndroid Build Coastguard Worker ) 867*da0073e9SAndroid Build Coastguard Worker # [--1--|---2---|-3-] [--1--|----2---|-3-] 868*da0073e9SAndroid Build Coastguard Worker v1 = t.flatten(0, 1) 869*da0073e9SAndroid Build Coastguard Worker v2 = v1.flatten(1, 3) 870*da0073e9SAndroid Build Coastguard Worker v3 = v2.flatten(2, 2) 871*da0073e9SAndroid Build Coastguard Worker test_writes_propagate(t, v1) 872*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of_same_base(t, v1)) 873*da0073e9SAndroid Build Coastguard Worker test_writes_propagate(t, v2) 874*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of_same_base(t, v2)) 875*da0073e9SAndroid Build Coastguard Worker test_writes_propagate(t, v3) 876*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of_same_base(t, v3)) 877*da0073e9SAndroid Build Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 879*da0073e9SAndroid Build Coastguard Worker def test_flatten_nonview(self, device): 880*da0073e9SAndroid Build Coastguard Worker def assert_is_nonview(t, nv): 881*da0073e9SAndroid Build Coastguard Worker idx_t = (0,) * t.ndim 
882*da0073e9SAndroid Build Coastguard Worker idx_nv = (0,) * nv.ndim 883*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not nv._is_view()) 884*da0073e9SAndroid Build Coastguard Worker nv[idx_nv] = 0 885*da0073e9SAndroid Build Coastguard Worker if device != "meta": 886*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(t[idx_t], nv[idx_nv]) 887*da0073e9SAndroid Build Coastguard Worker 888*da0073e9SAndroid Build Coastguard Worker t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3) 889*da0073e9SAndroid Build Coastguard Worker nv = t.flatten(1, 3) 890*da0073e9SAndroid Build Coastguard Worker assert_is_nonview(t, nv) 891*da0073e9SAndroid Build Coastguard Worker 892*da0073e9SAndroid Build Coastguard Worker t = torch.ones(2, 2, device=device).T 893*da0073e9SAndroid Build Coastguard Worker nv = t.flatten() 894*da0073e9SAndroid Build Coastguard Worker assert_is_nonview(t, nv) 895*da0073e9SAndroid Build Coastguard Worker 896*da0073e9SAndroid Build Coastguard Worker # flatten returns the original object if start_dim=end_dim 897*da0073e9SAndroid Build Coastguard Worker t = t = torch.ones(2, 2, device=device) 898*da0073e9SAndroid Build Coastguard Worker nv = t.flatten(1, 1) 899*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t is nv) 900*da0073e9SAndroid Build Coastguard Worker 901*da0073e9SAndroid Build Coastguard Worker def test_basic_indexing_slice_view(self, device): 902*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 903*da0073e9SAndroid Build Coastguard Worker v = t[:2, :3] 904*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 905*da0073e9SAndroid Build Coastguard Worker 906*da0073e9SAndroid Build Coastguard Worker v[0, 0] = 0 907*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[0, 0], v[0, 0]) 908*da0073e9SAndroid Build Coastguard Worker 909*da0073e9SAndroid Build Coastguard Worker def test_basic_indexing_ellipses_view(self, device): 910*da0073e9SAndroid Build 
Coastguard Worker t = torch.ones(5, 5, device=device) 911*da0073e9SAndroid Build Coastguard Worker v = t[..., :2] 912*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 913*da0073e9SAndroid Build Coastguard Worker 914*da0073e9SAndroid Build Coastguard Worker v[0, 0] = 0 915*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[0, 0], v[0, 0]) 916*da0073e9SAndroid Build Coastguard Worker 917*da0073e9SAndroid Build Coastguard Worker def test_basic_indexing_newaxis_view(self, device): 918*da0073e9SAndroid Build Coastguard Worker t = torch.ones(5, 5, device=device) 919*da0073e9SAndroid Build Coastguard Worker v = t[None, :2, 3] 920*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 921*da0073e9SAndroid Build Coastguard Worker 922*da0073e9SAndroid Build Coastguard Worker v[0, 0] = 0 923*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[0, 3], v[0, 0]) 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker def test_advanced_indexing_nonview(self, device): 926*da0073e9SAndroid Build Coastguard Worker t = torch.ones(3, 3, device=device) 927*da0073e9SAndroid Build Coastguard Worker rows = torch.tensor([[0, 0], [2, 2]], device=device) 928*da0073e9SAndroid Build Coastguard Worker cols = torch.tensor([[0, 1], [2, 2]], device=device) 929*da0073e9SAndroid Build Coastguard Worker nv = t[rows, cols] 930*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not self.is_view_of(t, nv)) 931*da0073e9SAndroid Build Coastguard Worker 932*da0073e9SAndroid Build Coastguard Worker nv[1, 1] = 0 933*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(t[2, 2], nv[1, 1]) 934*da0073e9SAndroid Build Coastguard Worker 935*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 936*da0073e9SAndroid Build Coastguard Worker IS_FBCODE, "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds" 937*da0073e9SAndroid Build Coastguard Worker ) 938*da0073e9SAndroid Build 
Coastguard Worker def test_advanced_indexing_assignment(self, device): 939*da0073e9SAndroid Build Coastguard Worker t = torch.ones(3, 3, device=device) 940*da0073e9SAndroid Build Coastguard Worker rows = torch.tensor([[0, 0], [2, 2]], device=device) 941*da0073e9SAndroid Build Coastguard Worker cols = torch.tensor([[0, 1], [2, 2]], device=device) 942*da0073e9SAndroid Build Coastguard Worker t[rows, cols] = 0 943*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[2, 2], 0) 944*da0073e9SAndroid Build Coastguard Worker 945*da0073e9SAndroid Build Coastguard Worker @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720") 946*da0073e9SAndroid Build Coastguard Worker def test_chunk_view(self, device): 947*da0073e9SAndroid Build Coastguard Worker t = torch.zeros(3, 3, device=device) 948*da0073e9SAndroid Build Coastguard Worker l = torch.chunk(t, 3) 949*da0073e9SAndroid Build Coastguard Worker 950*da0073e9SAndroid Build Coastguard Worker for idx, v in enumerate(l): 951*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 952*da0073e9SAndroid Build Coastguard Worker 953*da0073e9SAndroid Build Coastguard Worker v[0, 0] = idx + 1 954*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[idx, 0], v[0, 0]) 955*da0073e9SAndroid Build Coastguard Worker 956*da0073e9SAndroid Build Coastguard Worker @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720") 957*da0073e9SAndroid Build Coastguard Worker def test_split_view(self, device): 958*da0073e9SAndroid Build Coastguard Worker t = torch.zeros(3, 3, device=device) 959*da0073e9SAndroid Build Coastguard Worker l = torch.split(t, [1, 1, 1]) 960*da0073e9SAndroid Build Coastguard Worker 961*da0073e9SAndroid Build Coastguard Worker for idx, v in enumerate(l): 962*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, v)) 963*da0073e9SAndroid Build Coastguard Worker 964*da0073e9SAndroid Build Coastguard Worker v[0, 0] = idx + 1 965*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(t[idx, 0], v[0, 0]) 966*da0073e9SAndroid Build Coastguard Worker 967*da0073e9SAndroid Build Coastguard Worker def test_movedim_view(self, device): 968*da0073e9SAndroid Build Coastguard Worker def run_test(device, op): 969*da0073e9SAndroid Build Coastguard Worker t = torch.zeros(3, 3, device=device) 970*da0073e9SAndroid Build Coastguard Worker out = op(t) 971*da0073e9SAndroid Build Coastguard Worker 972*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(t, out)) 973*da0073e9SAndroid Build Coastguard Worker 974*da0073e9SAndroid Build Coastguard Worker # Randomly change values in output 975*da0073e9SAndroid Build Coastguard Worker # and verify that original is changed 976*da0073e9SAndroid Build Coastguard Worker # as well. 977*da0073e9SAndroid Build Coastguard Worker for _ in range(3): 978*da0073e9SAndroid Build Coastguard Worker idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2) 979*da0073e9SAndroid Build Coastguard Worker out[idx_1, idx_2] = random.random() 980*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2]) 981*da0073e9SAndroid Build Coastguard Worker 982*da0073e9SAndroid Build Coastguard Worker for fn in [torch.movedim, torch.moveaxis]: 983*da0073e9SAndroid Build Coastguard Worker op = partial(fn, source=(0, 1), destination=(1, 0)) 984*da0073e9SAndroid Build Coastguard Worker run_test(device, op) 985*da0073e9SAndroid Build Coastguard Worker 986*da0073e9SAndroid Build Coastguard Worker op = partial(fn, source=0, destination=1) 987*da0073e9SAndroid Build Coastguard Worker run_test(device, op) 988*da0073e9SAndroid Build Coastguard Worker 989*da0073e9SAndroid Build Coastguard Worker # Testing that the generated view_copy kernel and its derivative are implemented correctly 990*da0073e9SAndroid Build Coastguard Worker def test_view_copy(self, device): 991*da0073e9SAndroid Build Coastguard Worker a = torch.randn(4, device=device, requires_grad=True) 
992*da0073e9SAndroid Build Coastguard Worker a_ref = a.clone().detach().requires_grad_() 993*da0073e9SAndroid Build Coastguard Worker a_view = a_ref.view(2, 2) 994*da0073e9SAndroid Build Coastguard Worker a_view_copy = torch.view_copy(a, (2, 2)) 995*da0073e9SAndroid Build Coastguard Worker 996*da0073e9SAndroid Build Coastguard Worker # view_copy ops don't preserve view relationship 997*da0073e9SAndroid Build Coastguard Worker self.assertTrue(self.is_view_of(a_ref, a_view)) 998*da0073e9SAndroid Build Coastguard Worker self.assertFalse(self.is_view_of(a, a_view_copy)) 999*da0073e9SAndroid Build Coastguard Worker 1000*da0073e9SAndroid Build Coastguard Worker a_view_copy.sum().backward() 1001*da0073e9SAndroid Build Coastguard Worker a_view.sum().backward() 1002*da0073e9SAndroid Build Coastguard Worker 1003*da0073e9SAndroid Build Coastguard Worker # forward and backward give the same shape + result 1004*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a_view_copy, a_view) 1005*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.grad, a_ref.grad) 1006*da0073e9SAndroid Build Coastguard Worker 1007*da0073e9SAndroid Build Coastguard Worker # Testing that the output of a view_copy kernel (by default) is contiguous. 
1008*da0073e9SAndroid Build Coastguard Worker def test_view_copy_output_contiguous(self, device): 1009*da0073e9SAndroid Build Coastguard Worker a = torch.randn(4, 4, 4, 4, device=device).to(memory_format=torch.channels_last) 1010*da0073e9SAndroid Build Coastguard Worker b = torch.ops.aten.slice_copy(a, 0, 0, 2) 1011*da0073e9SAndroid Build Coastguard Worker self.assertTrue(b.is_contiguous()) 1012*da0073e9SAndroid Build Coastguard Worker 1013*da0073e9SAndroid Build Coastguard Worker def test_view_copy_out(self, device): 1014*da0073e9SAndroid Build Coastguard Worker a = torch.randn(2, 2, device=device) 1015*da0073e9SAndroid Build Coastguard Worker out = torch.empty(2, device=device) 1016*da0073e9SAndroid Build Coastguard Worker 1017*da0073e9SAndroid Build Coastguard Worker torch.diagonal_copy(a, out=out) 1018*da0073e9SAndroid Build Coastguard Worker expected = torch.diagonal_copy(a) 1019*da0073e9SAndroid Build Coastguard Worker 1020*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, out) 1021*da0073e9SAndroid Build Coastguard Worker 1022*da0073e9SAndroid Build Coastguard Worker a = torch.randn(4, device=device) 1023*da0073e9SAndroid Build Coastguard Worker out1 = torch.empty(2, device=device) 1024*da0073e9SAndroid Build Coastguard Worker out2 = torch.empty(2, device=device) 1025*da0073e9SAndroid Build Coastguard Worker 1026*da0073e9SAndroid Build Coastguard Worker torch.split_copy(a, 2, out=(out1, out2)) 1027*da0073e9SAndroid Build Coastguard Worker expected1, expected2 = torch.split_copy(a, 2) 1028*da0073e9SAndroid Build Coastguard Worker 1029*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected1, out1) 1030*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected2, out2) 1031*da0073e9SAndroid Build Coastguard Worker 1032*da0073e9SAndroid Build Coastguard Worker 1033*da0073e9SAndroid Build Coastguard Workerclass TestOldViewOps(TestCase): 1034*da0073e9SAndroid Build Coastguard Worker def test_ravel(self, device): 
1035*da0073e9SAndroid Build Coastguard Worker def _test_ravel(tensors, size, nc=False): 1036*da0073e9SAndroid Build Coastguard Worker for src in tensors: 1037*da0073e9SAndroid Build Coastguard Worker # Continuous Tensor -> View 1038*da0073e9SAndroid Build Coastguard Worker flat = src.ravel() 1039*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat.shape, torch.Size([size])) 1040*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src.view(-1), flat) 1041*da0073e9SAndroid Build Coastguard Worker self.assertIs(flat._base, src) 1042*da0073e9SAndroid Build Coastguard Worker self.assertTrue(flat.is_contiguous()) 1043*da0073e9SAndroid Build Coastguard Worker 1044*da0073e9SAndroid Build Coastguard Worker # Non-continuous Tensor -> Copy 1045*da0073e9SAndroid Build Coastguard Worker if nc: 1046*da0073e9SAndroid Build Coastguard Worker nc_src = src.t() 1047*da0073e9SAndroid Build Coastguard Worker nc_flat = nc_src.ravel() 1048*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nc_flat.shape, torch.Size([size])) 1049*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nc_src.contiguous().view(-1), nc_flat) 1050*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(nc_flat._base, src) 1051*da0073e9SAndroid Build Coastguard Worker self.assertTrue(nc_flat.is_contiguous()) 1052*da0073e9SAndroid Build Coastguard Worker 1053*da0073e9SAndroid Build Coastguard Worker # Test that flatten returns 1-dim tensor when given a 0-dim tensor 1054*da0073e9SAndroid Build Coastguard Worker zero_dim_tensor = torch.tensor(123, device=device) 1055*da0073e9SAndroid Build Coastguard Worker flat0 = zero_dim_tensor.ravel() 1056*da0073e9SAndroid Build Coastguard Worker one_dim_tensor = torch.tensor([123], device=device) 1057*da0073e9SAndroid Build Coastguard Worker flat1 = zero_dim_tensor.ravel() 1058*da0073e9SAndroid Build Coastguard Worker nc_ones_tensor = torch.ones(10, device=device)[::2] 1059*da0073e9SAndroid Build Coastguard Worker flat2 = nc_ones_tensor.ravel() 
1060*da0073e9SAndroid Build Coastguard Worker 1061*da0073e9SAndroid Build Coastguard Worker self.assertEqual(zero_dim_tensor.shape, torch.Size([])) 1062*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat0.shape, torch.Size([1])) 1063*da0073e9SAndroid Build Coastguard Worker self.assertEqual(one_dim_tensor.shape, torch.Size([1])) 1064*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat1.shape, torch.Size([1])) 1065*da0073e9SAndroid Build Coastguard Worker self.assertEqual(nc_ones_tensor.shape, torch.Size([5])) 1066*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat2.shape, torch.Size([5])) 1067*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat0, one_dim_tensor) 1068*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat0, flat1) 1069*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat0.shape, flat1.shape) 1070*da0073e9SAndroid Build Coastguard Worker self.assertTrue(flat0.is_contiguous()) 1071*da0073e9SAndroid Build Coastguard Worker self.assertTrue(flat1.is_contiguous()) 1072*da0073e9SAndroid Build Coastguard Worker self.assertTrue(flat2.is_contiguous()) 1073*da0073e9SAndroid Build Coastguard Worker 1074*da0073e9SAndroid Build Coastguard Worker # Test both float tensor and quantized tensor 1075*da0073e9SAndroid Build Coastguard Worker tensors = [ 1076*da0073e9SAndroid Build Coastguard Worker torch.randn(5, 5, 5, 5, device=device), 1077*da0073e9SAndroid Build Coastguard Worker torch._empty_affine_quantized( 1078*da0073e9SAndroid Build Coastguard Worker [5, 5, 5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device 1079*da0073e9SAndroid Build Coastguard Worker ), 1080*da0073e9SAndroid Build Coastguard Worker ] 1081*da0073e9SAndroid Build Coastguard Worker _test_ravel(tensors, 625) 1082*da0073e9SAndroid Build Coastguard Worker 1083*da0073e9SAndroid Build Coastguard Worker tensors = [ 1084*da0073e9SAndroid Build Coastguard Worker torch.randn(0, 2, 3, device=device), 1085*da0073e9SAndroid Build 
Coastguard Worker torch.randn(3, 0, 2, device=device), 1086*da0073e9SAndroid Build Coastguard Worker torch._empty_affine_quantized( 1087*da0073e9SAndroid Build Coastguard Worker [0, 2, 3], scale=2, zero_point=3, dtype=torch.quint8, device=device 1088*da0073e9SAndroid Build Coastguard Worker ), 1089*da0073e9SAndroid Build Coastguard Worker torch._empty_affine_quantized( 1090*da0073e9SAndroid Build Coastguard Worker [3, 0, 2], scale=2, zero_point=3, dtype=torch.quint8, device=device 1091*da0073e9SAndroid Build Coastguard Worker ), 1092*da0073e9SAndroid Build Coastguard Worker ] 1093*da0073e9SAndroid Build Coastguard Worker _test_ravel(tensors, 0) 1094*da0073e9SAndroid Build Coastguard Worker 1095*da0073e9SAndroid Build Coastguard Worker tensors = [ 1096*da0073e9SAndroid Build Coastguard Worker torch.randn(5, 5, device=device), 1097*da0073e9SAndroid Build Coastguard Worker torch._empty_affine_quantized( 1098*da0073e9SAndroid Build Coastguard Worker [5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device 1099*da0073e9SAndroid Build Coastguard Worker ), 1100*da0073e9SAndroid Build Coastguard Worker ] 1101*da0073e9SAndroid Build Coastguard Worker _test_ravel(tensors, 25, True) 1102*da0073e9SAndroid Build Coastguard Worker 1103*da0073e9SAndroid Build Coastguard Worker # TODO: this should be refactored into the view ops test suite 1104*da0073e9SAndroid Build Coastguard Worker def test_empty_reshape(self, device): 1105*da0073e9SAndroid Build Coastguard Worker x = torch.randn(0, 6, device=device) 1106*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape) 1107*da0073e9SAndroid Build Coastguard Worker # should be viewable -- i.e. data_ptr is the same. 
1108*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr()) 1109*da0073e9SAndroid Build Coastguard Worker 1110*da0073e9SAndroid Build Coastguard Worker # match NumPy semantics -- don't infer the size of dimension with a degree of freedom 1111*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) 1112*da0073e9SAndroid Build Coastguard Worker 1113*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 1114*da0073e9SAndroid Build Coastguard Worker def test_expand(self, device): 1115*da0073e9SAndroid Build Coastguard Worker tensor = torch.rand(1, 8, 1, device=device) 1116*da0073e9SAndroid Build Coastguard Worker tensor2 = torch.rand(5, device=device) 1117*da0073e9SAndroid Build Coastguard Worker template = torch.rand(4, 8, 5, device=device) 1118*da0073e9SAndroid Build Coastguard Worker target = template.size() 1119*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.expand_as(template).size(), target) 1120*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.expand(4, 8, 5).size(), target) 1121*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.expand(target).size(), target) 1122*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor2.expand_as(template).size(), target) 1123*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor2.expand(4, 8, 5).size(), target) 1124*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor2.expand(target).size(), target) 1125*da0073e9SAndroid Build Coastguard Worker 1126*da0073e9SAndroid Build Coastguard Worker # test double expand 1127*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1)) 1128*da0073e9SAndroid Build Coastguard Worker 1129*da0073e9SAndroid Build Coastguard Worker # test non-contiguous 1130*da0073e9SAndroid Build Coastguard Worker noncontig = 
torch.randn(5, 2, 1, 3, device=device)[:, 0] 1131*da0073e9SAndroid Build Coastguard Worker self.assertFalse(noncontig.is_contiguous()) 1132*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1133*da0073e9SAndroid Build Coastguard Worker noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1) 1134*da0073e9SAndroid Build Coastguard Worker ) 1135*da0073e9SAndroid Build Coastguard Worker 1136*da0073e9SAndroid Build Coastguard Worker # make sure it's compatible with unsqueeze 1137*da0073e9SAndroid Build Coastguard Worker expanded = tensor2.expand(1, 1, 5) 1138*da0073e9SAndroid Build Coastguard Worker unsqueezed = tensor2.unsqueeze(0).unsqueeze(1) 1139*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expanded, unsqueezed) 1140*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expanded.stride(), unsqueezed.stride()) 1141*da0073e9SAndroid Build Coastguard Worker 1142*da0073e9SAndroid Build Coastguard Worker # test -1 as target size 1143*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5)) 1144*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1)) 1145*da0073e9SAndroid Build Coastguard Worker 1146*da0073e9SAndroid Build Coastguard Worker # test expanding empty to empty 1147*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1148*da0073e9SAndroid Build Coastguard Worker torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device) 1149*da0073e9SAndroid Build Coastguard Worker ) 1150*da0073e9SAndroid Build Coastguard Worker 1151*da0073e9SAndroid Build Coastguard Worker # TODO: this should be refactored into the view ops test suite 1152*da0073e9SAndroid Build Coastguard Worker def test_view_empty(self, device): 1153*da0073e9SAndroid Build Coastguard Worker x = torch.randn(0, 6, device=device) 1154*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape) 
1155*da0073e9SAndroid Build Coastguard Worker 1156*da0073e9SAndroid Build Coastguard Worker # TODO: this should be refactored into the view ops test suite 1157*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1158*da0073e9SAndroid Build Coastguard Worker def test_reshape(self, device): 1159*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3, device=device) 1160*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr()) 1161*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) 1162*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) 1163*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) 1164*da0073e9SAndroid Build Coastguard Worker 1165*da0073e9SAndroid Build Coastguard Worker y = torch.randn(4, 4, 4, device=device)[:, 0, :] 1166*da0073e9SAndroid Build Coastguard Worker # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape 1167*da0073e9SAndroid Build Coastguard Worker if device != "meta": 1168*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) 1169*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.contiguous().view(-1), y.reshape(-1)) 1170*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr()) 1171*da0073e9SAndroid Build Coastguard Worker 1172*da0073e9SAndroid Build Coastguard Worker s = torch.randn((), device=device) 1173*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr()) 1174*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.reshape(-1).shape, (1,)) 1175*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: s.reshape(2)) 1176*da0073e9SAndroid Build Coastguard Worker 1177*da0073e9SAndroid Build Coastguard Worker empty = torch.tensor([], 
device=device) 1178*da0073e9SAndroid Build Coastguard Worker self.assertEqual(empty, empty.reshape(-1)) 1179*da0073e9SAndroid Build Coastguard Worker self.assertEqual(empty, empty.reshape([0])) 1180*da0073e9SAndroid Build Coastguard Worker # TODO: fix these once we have multi-dimensional empty tensors 1181*da0073e9SAndroid Build Coastguard Worker self.assertEqual(empty.reshape([0, 1]).shape, (0, 1)) 1182*da0073e9SAndroid Build Coastguard Worker self.assertEqual(empty.reshape([1, -1]).shape, (1, 0)) 1183*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: empty.reshape(1)) 1184*da0073e9SAndroid Build Coastguard Worker 1185*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3, device=device) 1186*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr()) 1187*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr()) 1188*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 1189*da0073e9SAndroid Build Coastguard Worker RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device)) 1190*da0073e9SAndroid Build Coastguard Worker ) 1191*da0073e9SAndroid Build Coastguard Worker 1192*da0073e9SAndroid Build Coastguard Worker def test_flatten(self, device): 1193*da0073e9SAndroid Build Coastguard Worker # Test that flatten returns 1-dim tensor when given a 0-dim tensor 1194*da0073e9SAndroid Build Coastguard Worker zero_dim_tensor = torch.tensor(123, device=device) 1195*da0073e9SAndroid Build Coastguard Worker flat0 = zero_dim_tensor.flatten() 1196*da0073e9SAndroid Build Coastguard Worker one_dim_tensor = torch.tensor([123], device=device) 1197*da0073e9SAndroid Build Coastguard Worker flat1 = zero_dim_tensor.flatten() 1198*da0073e9SAndroid Build Coastguard Worker 1199*da0073e9SAndroid Build Coastguard Worker self.assertEqual(zero_dim_tensor.shape, torch.Size([])) 1200*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(flat0.shape, torch.Size([1])) 1201*da0073e9SAndroid Build Coastguard Worker self.assertEqual(one_dim_tensor.shape, torch.Size([1])) 1202*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat1.shape, torch.Size([1])) 1203*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat0, one_dim_tensor) 1204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat0, flat1) 1205*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat0.shape, flat1.shape) 1206*da0073e9SAndroid Build Coastguard Worker 1207*da0073e9SAndroid Build Coastguard Worker # Test both float tensor and quantized tensor 1208*da0073e9SAndroid Build Coastguard Worker tensors = [ 1209*da0073e9SAndroid Build Coastguard Worker torch.randn(5, 5, 5, 5, device=device), 1210*da0073e9SAndroid Build Coastguard Worker torch._empty_affine_quantized( 1211*da0073e9SAndroid Build Coastguard Worker [5, 5, 5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device 1212*da0073e9SAndroid Build Coastguard Worker ), 1213*da0073e9SAndroid Build Coastguard Worker ] 1214*da0073e9SAndroid Build Coastguard Worker for src in tensors: 1215*da0073e9SAndroid Build Coastguard Worker flat = src.flatten(0, -1) 1216*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat.shape, torch.Size([625])) 1217*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src.view(-1), flat.view(-1)) 1218*da0073e9SAndroid Build Coastguard Worker 1219*da0073e9SAndroid Build Coastguard Worker flat = src.flatten(0, 2) 1220*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat.shape, torch.Size([125, 5])) 1221*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src.view(-1), flat.view(-1)) 1222*da0073e9SAndroid Build Coastguard Worker 1223*da0073e9SAndroid Build Coastguard Worker flat = src.flatten(0, 1) 1224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat.shape, torch.Size([25, 5, 5])) 1225*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src.view(-1), 
flat.view(-1)) 1226*da0073e9SAndroid Build Coastguard Worker 1227*da0073e9SAndroid Build Coastguard Worker flat = src.flatten(1, 2) 1228*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat.shape, torch.Size([5, 25, 5])) 1229*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src.view(-1), flat.view(-1)) 1230*da0073e9SAndroid Build Coastguard Worker 1231*da0073e9SAndroid Build Coastguard Worker flat = src.flatten(2, 3) 1232*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat.shape, torch.Size([5, 5, 25])) 1233*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src.view(-1), flat.view(-1)) 1234*da0073e9SAndroid Build Coastguard Worker 1235*da0073e9SAndroid Build Coastguard Worker flat = src.flatten(-2, -1) 1236*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat.shape, torch.Size([5, 5, 25])) 1237*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src.view(-1), flat.view(-1)) 1238*da0073e9SAndroid Build Coastguard Worker 1239*da0073e9SAndroid Build Coastguard Worker flat = src.flatten(2, 2) 1240*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat, src) 1241*da0073e9SAndroid Build Coastguard Worker 1242*da0073e9SAndroid Build Coastguard Worker # out of bounds index 1243*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, "Dimension out of range"): 1244*da0073e9SAndroid Build Coastguard Worker src.flatten(5, 10) 1245*da0073e9SAndroid Build Coastguard Worker 1246*da0073e9SAndroid Build Coastguard Worker # invalid start and end 1247*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1248*da0073e9SAndroid Build Coastguard Worker RuntimeError, "start_dim cannot come after end_dim" 1249*da0073e9SAndroid Build Coastguard Worker ): 1250*da0073e9SAndroid Build Coastguard Worker src.flatten(2, 0) 1251*da0073e9SAndroid Build Coastguard Worker 1252*da0073e9SAndroid Build Coastguard Worker # TODO: update to work on CUDA, too 1253*da0073e9SAndroid Build Coastguard Worker 
@onlyCPU
def test_narrow(self, device):
    """narrow() with positive/negative starts and negative dims matches slicing."""
    x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]]))
    self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]]))
    self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]]))
    self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]]))
    self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]]))
    self.assertEqual(
        x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    )
    self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]]))
    self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]]))

# TODO: update to work on CUDA, too
@onlyCPU
def test_narrow_tensor(self, device):
    """narrow() accepts a 0-d integer tensor start and rejects other tensor starts."""
    x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
    self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))
    # Floating-point, 1-element 1-D, and multi-element starts must all be rejected;
    # each gets its own assertRaises so one raising cannot mask another.
    for bad_start in (torch.tensor(0.0), torch.tensor([0]), torch.tensor([0, 1])):
        with self.assertRaises(Exception):
            x.narrow(0, bad_start, 1)

# TODO: make work on CUDA, too
@onlyCPU
def test_t(self, device):
    """t() is identity for 0-D/1-D, equals transpose(0, 1) for 2-D, errors on 3-D.

    Each case is checked for both strided and sparse (to_sparse()) layouts.
    """
    # Test 0D tensors
    x = torch.randn(())
    self.assertEqual(x, x.t())
    x = x.to_sparse()
    self.assertEqual(x, x.t())

    # Test 1D tensors
    x = torch.arange(4)
    self.assertEqual(x, x.t())
    x = x.to_sparse()
    self.assertEqual(x, x.t())

    # Test 2D tensors
    x = torch.rand((2, 2))
    self.assertEqual(x.t(), x.transpose(0, 1))
    x = x.to_sparse()
    self.assertEqual(x.t(), x.transpose(0, 1))

    # Test 3D tensor
    x = torch.rand((2, 2, 2))
    with self.assertRaisesRegex(
        RuntimeError, "expects a tensor with <= 2 dimensions, but self is 3D"
    ):
        x.t()
    x = x.to_sparse()
    with self.assertRaisesRegex(
        RuntimeError, "expects a tensor with <= 2 sparse and 0 dense dimensions"
    ):
        x.t()

@onlyCPU
def test_split(self, device):
    """split() by an int size and by explicit section sizes matches narrow()."""

    def check_splits(tensor, split_arg, dim, target_sizes):
        # Each piece must have the expected size and equal the matching narrow()
        # window; the piece count is asserted explicitly because zip() would
        # otherwise silently ignore missing or extra pieces.
        splits = tensor.split(split_arg, dim)
        self.assertEqual(len(splits), len(target_sizes))
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(
                tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0
            )
            start += target_size[dim]

    # Fixed split size; the last piece holds the remainder (7 = 3 + 3 + 1).
    check_splits(torch.rand(7, 4), 3, 0, ([3, 4], [3, 4], [1, 4]))

    # Variable sections split, along dim 0 and then dim 1 of the same tensor
    tensor = torch.randn(20, 10)
    check_splits(tensor, [5, 5, 10], 0, [[5, 10], [5, 10], [10, 10]])
    check_splits(tensor, [2, 2, 6], 1, ([20, 2], [20, 2], [20, 6]))

@onlyCPU
def test_chunk(self, device):
    """chunk() pieces match narrow(), and a non-positive chunk count is an error."""
    tensor = torch.rand(4, 7)
    num_chunks = 3
    dim = 1
    target_sizes = ([4, 3], [4, 3], [4, 1])
    splits = tensor.chunk(num_chunks, dim)
    # zip() below would hide a wrong number of chunks, so check it explicitly.
    self.assertEqual(len(splits), len(target_sizes))
    start = 0
    for target_size, split in zip(target_sizes, splits):
        self.assertEqual(split.size(), target_size)
        self.assertEqual(
            tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0
        )
        start += target_size[dim]

    # Invalid chunk sizes
    error_regex = "chunk expects.*greater than 0"
    for bad_num_chunks in (0, -2):
        with self.assertRaisesRegex(RuntimeError, error_regex):
            tensor.chunk(bad_num_chunks)

# TODO: make work on CUDA, too
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
@onlyCPU
def test_unsqueeze(self, device) -> None:
    """unsqueeze/unsqueeze_ match view() for contiguous and non-contiguous inputs."""
    x = torch.randn(2, 3, 4)
    y = x.unsqueeze(1)
    self.assertEqual(y, x.view(2, 1, 3, 4))
    y = x.clone().unsqueeze_(2)
    self.assertEqual(y, x.view(2, 3, 1, 4))

    # Selecting along dim 1 yields a non-contiguous (2, 4) view.
    x = x[:, 1]
    self.assertFalse(x.is_contiguous())
    y = x.unsqueeze(1)
    self.assertEqual(y, x.contiguous().view(2, 1, 4))
    y = x.clone().unsqueeze_(2)
    self.assertEqual(y, x.contiguous().view(2, 4, 1))

# unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
def test_big_transpose(self, device):
    """A contiguous copy of a large transposed tensor matches NumPy's transpose."""
    t = torch.rand(456, 789, device=device)
    t1 = t.t().contiguous()
    t2 = torch.from_numpy(t.cpu().numpy().transpose())
    self.assertEqual(t1, t2)

def test_T(self, device):
    """.T reverses all dims of a 3-D tensor and is the identity on a 1-D tensor."""
    a = torch.randn(2, 3, 4, device=device)
    t1 = a.T
    t2 = a.permute(2, 1, 0)
    self.assertEqual(t2, t1)
    b = torch.randn(10, device=device)
    self.assertEqual(b, b.T)

@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_transposes(self, device, dtype):
    """T/H/mT/mH/adjoint equal transpose(-2, -1), conjugated for H, mH and adjoint.

    T and H are checked on a 2-D input only; mT, mH and adjoint also on a 3-D input.
    """
    for op in ("T", "H", "mT", "mH", "adjoint"):
        shapes = (
            ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),)
        )
        for shape in shapes:
            a = make_tensor(shape, device=device, dtype=dtype)
            actual = getattr(a, op)
            if op == "adjoint":
                # adjoint is a method; the others are properties.
                actual = actual()
            expected = a.transpose(-2, -1)
            if op[-1] == "H" or op == "adjoint":
                expected = expected.conj()
            self.assertEqual(expected, actual)

@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
def test_transposes_errors(self, device, dtype):
    """H rejects 1-D and 3-D inputs; mT/mH/adjoint reject 1-D inputs."""
    for op in ("H", "mT", "mH", "adjoint"):
        shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
        for shape in shapes:
            a = make_tensor(shape, device=device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
                attr = getattr(a, op)
                if op == "adjoint":
                    # adjoint only raises once the bound method is called.
                    attr()

def test_python_types(self, device):
    """Python float/int/bool used as dtypes map to float64/int64/bool."""
    a1 = torch.randn((1, 2), device=device, dtype=torch.float64)
    a2 = torch.randn((1, 2), device=device, dtype=float)
    self.assertEqual(a1.dtype, a2.dtype)

    b1 = torch.arange(10, 20, dtype=torch.int64, device=device)
    b2 = torch.arange(10, 20, dtype=int, device=device)
    self.assertEqual(b1.dtype, b2.dtype)

    c1 = torch.tensor([True, False], dtype=torch.bool, device=device)
    c2 = torch.tensor([True, False], dtype=bool, device=device)
    self.assertEqual(c1.dtype, c2.dtype)

# TODO: is resize best put in test_view_ops?
def test_resize_as_preserves_strides(self, device):
    """resize_as_ onto its own (transposed) shape leaves the strides unchanged."""
    x = torch.empty(2, 3).t()
    old_strides = x.stride()
    x.resize_as_(x)
    self.assertEqual(x.stride(), old_strides)

def test_memory_format_resize_as(self, device):
    """resize_as_ with preserve_format adopts the template's channels-last layout."""

    def test_helper(shape, memory_format, device):
        xc = torch.randn(shape, device=device).contiguous(
            memory_format=memory_format
        )
        flat = torch.randn(xc.numel(), device=device)
        flat.resize_as_(xc, memory_format=torch.preserve_format)
        self.assertTrue(flat.is_contiguous(memory_format=memory_format))

    test_helper((10, 3, 32, 32), torch.channels_last, device)
    test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device)

def test_memory_format_resize_(self, device):
    """resize_ with an explicit memory_format yields a tensor contiguous in it."""

    def test_helper(shape, numel, memory_format, device):
        flat = torch.randn(numel, device=device)
        flat.resize_(shape, memory_format=memory_format)
        self.assertTrue(flat.is_contiguous(memory_format=memory_format))

    test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device)
    test_helper(
        (3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device
    )

@onlyNativeDeviceTypes
@dtypes(torch.int64, torch.float, torch.complex128)
def test_transpose_invalid(self, device, dtype):
    """swapdims/swapaxes/transpose raise IndexError for an out-of-range dim."""
    for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
        shape = _rand_shape(4, min_size=5, max_size=10)
        x = _generate_input(shape, dtype, device, False)

        # Invalid `source` and `destination` dimension
        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            fn(x, 5, 0)

        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            fn(x, 0, 5)

@dtypes(torch.int64, torch.float, torch.complex128)
def test_transpose_vs_numpy(self, device, dtype):
    """swapdims/swapaxes/transpose agree with np.swapaxes for every dim pair.

    Dims are randomly made negative (source, destination, or both) to cover
    negative indexing.
    """

    def bind_dims(fn, dim0, dim1):
        # torch.swapaxes names its dimension arguments axis0/axis1, while
        # torch.swapdims and torch.transpose use dim0/dim1.
        if fn is torch.swapaxes:
            return partial(torch.swapaxes, axis0=dim0, axis1=dim1)
        return partial(fn, dim0=dim0, dim1=dim1)

    for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
        for nd in range(5):
            shape = _rand_shape(nd, min_size=5, max_size=10)
            x = _generate_input(shape, dtype, device, with_extremal=False)
            for random_negative in [True, False]:
                for src_dim, dst_dim in permutations(range(nd), r=2):
                    random_prob = random.random()

                    if random_negative and random_prob > 0.66:
                        src_dim = src_dim - nd
                    elif random_negative and random_prob > 0.33:
                        dst_dim = dst_dim - nd
                    elif random_negative:
                        src_dim = src_dim - nd
                        dst_dim = dst_dim - nd

                    torch_fn = bind_dims(fn, src_dim, dst_dim)
                    np_fn = partial(np.swapaxes, axis1=src_dim, axis2=dst_dim)
                    self.compare_with_numpy(
                        torch_fn, np_fn, x, device=None, dtype=None
                    )

        # Move dim to same position
        x = torch.randn(2, 3, 5, 7, 11)
        torch_fn = bind_dims(fn, 0, 0)
        np_fn = partial(np.swapaxes, axis1=0, axis2=0)
        self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

def _test_atleast_dim(self, torch_fn, np_fn, device, dtype):
    """Compare torch_fn with np_fn on random inputs of 0-4 dims.

    Both a single input and a sequence of repeated inputs are compared.
    """
    for ndims in range(0, 5):
        shape = _rand_shape(ndims, min_size=5, max_size=10)
        # The loop index is unused: each shape is simply re-tested ndims + 1
        # times with freshly drawn random data.
        for _ in range(ndims + 1):
            for with_extremal in [False, True]:
                for transpose in [False, True]:
                    # Generate Input.
                    x = _generate_input(shape, dtype, device, with_extremal)
                    if transpose:
                        # Reversed-dims view (non-contiguous when 2+ dims are > 1).
                        x = x.T
                    self.compare_with_numpy(
                        torch_fn, np_fn, x, device=None, dtype=None
                    )

                    # Compare sequence input
                    torch_sequence_x = (x,) * random.randint(3, 10)
                    np_sequence_x = tuple(
                        np.array(t.detach().cpu().numpy()) for t in torch_sequence_x
                    )
                    torch_res = torch_fn(*torch_sequence_x)
                    np_res = np_fn(*np_sequence_x)

                    torch_res = tuple(t.cpu() for t in torch_res)
                    np_res = tuple(torch.from_numpy(arr) for arr in np_res)
                    self.assertEqual(np_res, torch_res)

# TODO: are these view ops?
@dtypes(*all_types_and_complex_and(torch.half))
def test_atleast(self, device, dtype):
    """atleast_1d/2d/3d match their NumPy counterparts."""
    self._test_atleast_dim(torch.atleast_1d, np.atleast_1d, device, dtype)
    self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype)
    self._test_atleast_dim(torch.atleast_3d, np.atleast_3d, device, dtype)

# TODO: OpInfo this
def _test_atleast(self, device, torch_fn):
    """Run gradcheck/gradgradcheck for torch_fn on 0- to 4-dim double inputs.

    Note: the inputs are created with the default device; `device` is not used.
    """
    # 0-dim
    s = torch.tensor(0.5, dtype=torch.double, requires_grad=True)

    gradcheck(torch_fn, s)
    gradgradcheck(torch_fn, s)

    # 1-dim
    a = torch.rand(4, dtype=torch.double, requires_grad=True)

    gradcheck(torch_fn, a)
    gradgradcheck(torch_fn, a)

    # 2,3,4-dim
    b = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
    c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
    d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)

    # gradcheck unpacks the tuple, so torch_fn receives all five tensors at once.
    input_tuple = (s, a, b, c, d)
    gradcheck(torch_fn, input_tuple)
    gradgradcheck(torch_fn, input_tuple)

def test_atleast_gradient(self, device):
    """Gradient checks for atleast_1d/2d/3d."""
    self._test_atleast(device, torch.atleast_1d)
    self._test_atleast(device, torch.atleast_2d)
    self._test_atleast(device, torch.atleast_3d)

@onlyCPU
@dtypes(torch.float)
def test_broadcast_tensors(self, device, dtype):
    """broadcast_tensors expands every input to the common broadcast shape."""
    x0 = torch.randn(2, 1, 3, dtype=dtype, device=device)
    x1 = torch.randn(3, dtype=dtype, device=device)
    x2 = torch.randn(3, 1, dtype=dtype, device=device)
    expected_size = torch.Size([2, 3, 3])

    y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2)
    # assertEqual reports both sizes on failure, unlike assertTrue(a == b).
    self.assertEqual(y0.size(), expected_size)
    self.assertEqual(y1.size(), expected_size)
    self.assertEqual(y2.size(), expected_size)

@onlyCPU
def test_broadcast_shapes(self, device):
    """broadcast_shapes agrees with broadcast_tensors and validates its inputs."""
    examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)]
    for s0 in examples:
        x0 = torch.randn(s0)
        expected = torch.broadcast_tensors(x0)[0].shape
        actual = torch.broadcast_shapes(s0)
        self.assertEqual(expected, actual)

        for s1 in examples:
            x1 = torch.randn(s1)
            expected = torch.broadcast_tensors(x0, x1)[0].shape
            actual = torch.broadcast_shapes(s0, s1)
            self.assertEqual(expected, actual)

    # Bare ints are treated as 1-D shapes.
    inputs_list = [[1, 4], [4, 1], [1, 1, 3]]
    for integral_inputs in inputs_list:
        res1 = torch.broadcast_shapes(*integral_inputs)
        res2 = torch.broadcast_tensors(*map(torch.empty, integral_inputs))[0].shape
        self.assertEqual(res1, res2)

    inputs_with_neg_vals = [[1, 1, -12], [-1, 1], [-11]]
    for integral_inputs_with_neg_vals in inputs_with_neg_vals:
        with self.assertRaisesRegex(
            RuntimeError, "Trying to create tensor with negative dimension"
        ):
            torch.broadcast_shapes(*integral_inputs_with_neg_vals)

    integral_inputs_error_case = [(3, 5), (2, 4, 1)]
    for error_input in integral_inputs_error_case:
        with self.assertRaisesRegex(
            RuntimeError,
            "Shape mismatch: objects cannot be broadcast to a single shape",
        ):
            torch.broadcast_shapes(*error_input)

    negative_inputs = [(-1,), (1, -12), (4, -11), (-4, 1), (1, 1, -2)]
    for s0 in negative_inputs:
        with self.assertRaisesRegex(
            RuntimeError, "Trying to create tensor with negative dimension"
        ):
            torch.broadcast_shapes(s0)

        for s1 in negative_inputs:
            with self.assertRaisesRegex(
                RuntimeError, "Trying to create tensor with negative dimension"
            ):
                torch.broadcast_shapes(s0, s1)

    float_inputs_error_case = [(1.1, 2.0), (1.1, 1.0)]
    for error_case in float_inputs_error_case:
        for float_input in error_case:
            with self.assertRaisesRegex(
                RuntimeError,
                "Input shapes "
                "should be of type ints, a tuple of ints, or a list of ints",
            ):
                torch.broadcast_shapes(float_input)

    # Mixing a bare int with tuple shapes.
    diff_input_types = [(1, (5,)), (3, (1,)), (1, (3, 4))]
    for s0 in diff_input_types:
        res1 = torch.broadcast_shapes(*s0)
        res2 = torch.broadcast_tensors(*map(torch.empty, s0))[0].shape
        self.assertEqual(res1, res2)

# Skip BFloat16 since numpy does not support it
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
def test_broadcast_to(self, device, dtype):
    """broadcast_to matches np.broadcast_to, or raises for incompatible shapes."""

    def can_broadcast(s0, s1):
        # Assumes len(s0) <= len(s1), which holds because `sizes` is ordered by
        # rank and combinations() preserves that order. Compare trailing dims:
        # each dim of s0 must be 1 or equal the corresponding dim of s1.
        return all(a == 1 or a == b for a, b in zip(reversed(s0), reversed(s1)))

    sizes = ((), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2))
    for s0, s1 in combinations(sizes, r=2):
        t = make_tensor(s0, dtype=dtype, device=device, low=-9, high=9)
        t_np = t.cpu().numpy()

        if can_broadcast(s0, s1):
            res = torch.broadcast_to(t, s1)
            np_res = np.broadcast_to(t_np, s1)
            self.assertEqual(res, np_res)
        else:
            with self.assertRaisesRegex(
                RuntimeError,
                r"The expanded size of the tensor \(\d\) "
                r"must match the existing size \(\d\)",
            ):
                torch.broadcast_to(t, s1)

def test_view(self, device):
    tensor = torch.rand(15, device=device)
    template = torch.rand(3, 5, device=device)
    empty = torch.empty(0, device=device)
    target = template.size()
    self.assertEqual(tensor.view_as(template).size(), target)
    self.assertEqual(tensor.view(3, 5).size(), target)
    self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target)
    self.assertEqual(tensor.view(-1, 5).size(), target)
1716*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.view(3, -1).size(), target) 1717*da0073e9SAndroid Build Coastguard Worker tensor_view = tensor.view(5, 3) 1718*da0073e9SAndroid Build Coastguard Worker tensor_view.fill_(random.uniform(0, 1)) 1719*da0073e9SAndroid Build Coastguard Worker self.assertEqual(empty.view_as(empty), empty) 1720*da0073e9SAndroid Build Coastguard Worker self.assertEqual(empty.view(0), empty) 1721*da0073e9SAndroid Build Coastguard Worker self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1])) 1722*da0073e9SAndroid Build Coastguard Worker self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty) 1723*da0073e9SAndroid Build Coastguard Worker 1724*da0073e9SAndroid Build Coastguard Worker # test size inference with empty tensors 1725*da0073e9SAndroid Build Coastguard Worker self.assertEqual(empty.view(-1).size(), torch.Size([0])) 1726*da0073e9SAndroid Build Coastguard Worker self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0])) 1727*da0073e9SAndroid Build Coastguard Worker 1728*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1729*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"because the unspecified dimension size -1 can be any value" 1730*da0073e9SAndroid Build Coastguard Worker ): 1731*da0073e9SAndroid Build Coastguard Worker empty.view(-1, 0) 1732*da0073e9SAndroid Build Coastguard Worker 1733*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1734*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"because the unspecified dimension size -1 can be any value" 1735*da0073e9SAndroid Build Coastguard Worker ): 1736*da0073e9SAndroid Build Coastguard Worker empty.view(3, 0, -1, 0) 1737*da0073e9SAndroid Build Coastguard Worker 1738*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: tensor.view(15, 0)) 1739*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) 
1740*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) 1741*da0073e9SAndroid Build Coastguard Worker 1742*da0073e9SAndroid Build Coastguard Worker # test view when tensor is not contiguous in every dimension, but only 1743*da0073e9SAndroid Build Coastguard Worker # contiguous dimensions are touched. 1744*da0073e9SAndroid Build Coastguard Worker tensor = ( 1745*da0073e9SAndroid Build Coastguard Worker torch.rand(4, 2, 5, 1, 6, 2, 9, 3, device=device) 1746*da0073e9SAndroid Build Coastguard Worker .transpose(-1, 2) 1747*da0073e9SAndroid Build Coastguard Worker .transpose(-2, 3) 1748*da0073e9SAndroid Build Coastguard Worker ) 1749*da0073e9SAndroid Build Coastguard Worker # size: [ 4, 2, 3, 9, 6, 2, 1, 5] 1750*da0073e9SAndroid Build Coastguard Worker # stride: [3840, 1620, 1, 3, 54, 27, 324, 324] 1751*da0073e9SAndroid Build Coastguard Worker # contiguous dim chunks: [__________, ____, ____, __________, ____, ____] 1752*da0073e9SAndroid Build Coastguard Worker # merging 1 to chunk after: [__________, ____, ____, __________, __________] 1753*da0073e9SAndroid Build Coastguard Worker contig_tensor = tensor.clone() 1754*da0073e9SAndroid Build Coastguard Worker # [4, 2] => [8, 1] 1755*da0073e9SAndroid Build Coastguard Worker # [3] => [3] 1756*da0073e9SAndroid Build Coastguard Worker # [9] => [3, 3] 1757*da0073e9SAndroid Build Coastguard Worker # [6, 2] => [4, 1, 3] 1758*da0073e9SAndroid Build Coastguard Worker # [1, 5] => [5] 1759*da0073e9SAndroid Build Coastguard Worker view_size = [8, 1, 3, 3, 3, 4, 1, 3, 5] 1760*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) 1761*da0073e9SAndroid Build Coastguard Worker # [4, 2] => [2, 4] 1762*da0073e9SAndroid Build Coastguard Worker # [3] => [3] 1763*da0073e9SAndroid Build Coastguard Worker # [9] => [1, 9] 1764*da0073e9SAndroid Build Coastguard Worker # [6, 2] => [2, 2, 3] 1765*da0073e9SAndroid Build Coastguard 
Worker # [1, 5] => [5, 1] 1766*da0073e9SAndroid Build Coastguard Worker view_size = [2, 4, 3, 1, 9, 2, 2, 3, 5, 1] 1767*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) 1768*da0073e9SAndroid Build Coastguard Worker # adding size 1 dims 1769*da0073e9SAndroid Build Coastguard Worker view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1] 1770*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) 1771*da0073e9SAndroid Build Coastguard Worker 1772*da0073e9SAndroid Build Coastguard Worker # invalid views 1773*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: tensor.view(-1)) 1774*da0073e9SAndroid Build Coastguard Worker # crossing [4, 2], [3] 1775*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: tensor.view(24, 9, 6, 2, 1, 5)) 1776*da0073e9SAndroid Build Coastguard Worker # crossing [6, 2], [1, 5] 1777*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 9, 6, 10)) 1778*da0073e9SAndroid Build Coastguard Worker # crossing [9], [6, 2] 1779*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 54, 2, 1, 5)) 1780*da0073e9SAndroid Build Coastguard Worker 1781*da0073e9SAndroid Build Coastguard Worker # view with stride 0 dims 1782*da0073e9SAndroid Build Coastguard Worker tensor = torch.empty(1, 1, device=device).expand( 1783*da0073e9SAndroid Build Coastguard Worker 3, 4 1784*da0073e9SAndroid Build Coastguard Worker ) # all dims are contiguous 1785*da0073e9SAndroid Build Coastguard Worker contig_tensor = tensor.clone() 1786*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.view(-1), contig_tensor.view(-1)) 1787*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1)) 1788*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(tensor.view(-1, 1), contig_tensor.view(-1, 1)) 1789*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1)) 1790*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) 1791*da0073e9SAndroid Build Coastguard Worker 1792*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 1793*da0073e9SAndroid Build Coastguard Worker def test_reshape_view_semantics(self, device, dtype): 1794*da0073e9SAndroid Build Coastguard Worker tensor = make_tensor((15, 4), dtype=dtype, device=device) 1795*da0073e9SAndroid Build Coastguard Worker target = (20, 3) 1796*da0073e9SAndroid Build Coastguard Worker 1797*da0073e9SAndroid Build Coastguard Worker # Cases where the tensor can be returned as a view. 1798*da0073e9SAndroid Build Coastguard Worker view_tensor = tensor.reshape(target) 1799*da0073e9SAndroid Build Coastguard Worker self.assertEqual((view_tensor.size()), target) 1800*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.storage().data_ptr(), view_tensor.storage().data_ptr()) 1801*da0073e9SAndroid Build Coastguard Worker 1802*da0073e9SAndroid Build Coastguard Worker # Cases where the tensor must be copied (transpose makes it non-contiguous forcing 1803*da0073e9SAndroid Build Coastguard Worker # the copy). 
1804*da0073e9SAndroid Build Coastguard Worker copy_tensor = tensor.transpose(0, 1).reshape(target) 1805*da0073e9SAndroid Build Coastguard Worker self.assertEqual(copy_tensor.size(), target) 1806*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual( 1807*da0073e9SAndroid Build Coastguard Worker tensor.storage().data_ptr(), copy_tensor.storage().data_ptr() 1808*da0073e9SAndroid Build Coastguard Worker ) 1809*da0073e9SAndroid Build Coastguard Worker 1810*da0073e9SAndroid Build Coastguard Worker def test_contiguous(self, device): 1811*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 16, 5, 5, device=device) 1812*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.is_contiguous()) 1813*da0073e9SAndroid Build Coastguard Worker stride = list(x.stride()) 1814*da0073e9SAndroid Build Coastguard Worker stride[0] = 20 1815*da0073e9SAndroid Build Coastguard Worker # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 1816*da0073e9SAndroid Build Coastguard Worker x.set_(x.storage(), 0, x.size(), stride) 1817*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.is_contiguous()) 1818*da0073e9SAndroid Build Coastguard Worker 1819*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1820*da0073e9SAndroid Build Coastguard Worker # Skip BFloat16 since numpy does not support it 1821*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) 1822*da0073e9SAndroid Build Coastguard Worker def test_tensor_split_sections(self, device, dtype): 1823*da0073e9SAndroid Build Coastguard Worker input_sizes = [ 1824*da0073e9SAndroid Build Coastguard Worker (0,), 1825*da0073e9SAndroid Build Coastguard Worker (10,), 1826*da0073e9SAndroid Build Coastguard Worker (10, 0), 1827*da0073e9SAndroid Build Coastguard Worker (0, 10), 1828*da0073e9SAndroid Build Coastguard Worker (4, 10), 1829*da0073e9SAndroid Build Coastguard Worker (12, 3), 1830*da0073e9SAndroid Build Coastguard Worker ] 
1831*da0073e9SAndroid Build Coastguard Worker for input_size in input_sizes: 1832*da0073e9SAndroid Build Coastguard Worker a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9) 1833*da0073e9SAndroid Build Coastguard Worker # Run tests on transposed input if it has at least 2 dims 1834*da0073e9SAndroid Build Coastguard Worker for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]: 1835*da0073e9SAndroid Build Coastguard Worker a_n = a.cpu().numpy() 1836*da0073e9SAndroid Build Coastguard Worker for dim in range(-a.dim(), a.dim()): 1837*da0073e9SAndroid Build Coastguard Worker for sections in range(1, 2 * a.size(dim)): 1838*da0073e9SAndroid Build Coastguard Worker msg = f"input_size {input_size}, sections {sections}, dim {dim}" 1839*da0073e9SAndroid Build Coastguard Worker result1 = torch.tensor_split(a, sections, dim) 1840*da0073e9SAndroid Build Coastguard Worker result2 = torch.tensor_split( 1841*da0073e9SAndroid Build Coastguard Worker a, torch.tensor(sections, dtype=torch.int64), dim 1842*da0073e9SAndroid Build Coastguard Worker ) 1843*da0073e9SAndroid Build Coastguard Worker for r1, r2 in zip(result1, result2): 1844*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1.device, torch.device(device), msg=msg) 1845*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1.dtype, dtype, msg=msg) 1846*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r2.device, torch.device(device), msg=msg) 1847*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r2.dtype, dtype, msg=msg) 1848*da0073e9SAndroid Build Coastguard Worker result_n = np.array_split(a_n, sections, dim) 1849*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_n, result1, msg=msg) 1850*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_n, result2, msg=msg) 1851*da0073e9SAndroid Build Coastguard Worker 1852*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1853*da0073e9SAndroid Build Coastguard Worker # Skip 
BFloat16 since numpy does not support it 1854*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) 1855*da0073e9SAndroid Build Coastguard Worker def test_tensor_split_indices(self, device, dtype): 1856*da0073e9SAndroid Build Coastguard Worker input_sizes = [ 1857*da0073e9SAndroid Build Coastguard Worker (0,), 1858*da0073e9SAndroid Build Coastguard Worker (10,), 1859*da0073e9SAndroid Build Coastguard Worker (10, 0), 1860*da0073e9SAndroid Build Coastguard Worker (0, 10), 1861*da0073e9SAndroid Build Coastguard Worker (4, 10), 1862*da0073e9SAndroid Build Coastguard Worker (12, 3), 1863*da0073e9SAndroid Build Coastguard Worker ] 1864*da0073e9SAndroid Build Coastguard Worker indices_args = [ 1865*da0073e9SAndroid Build Coastguard Worker (), 1866*da0073e9SAndroid Build Coastguard Worker (0,), 1867*da0073e9SAndroid Build Coastguard Worker (3,), 1868*da0073e9SAndroid Build Coastguard Worker (10,), 1869*da0073e9SAndroid Build Coastguard Worker (-1,), 1870*da0073e9SAndroid Build Coastguard Worker (-10,), 1871*da0073e9SAndroid Build Coastguard Worker (2, -1), 1872*da0073e9SAndroid Build Coastguard Worker (3, 4, 10), 1873*da0073e9SAndroid Build Coastguard Worker (0, -1, 0, 10), 1874*da0073e9SAndroid Build Coastguard Worker (1, 5, 2, 8), 1875*da0073e9SAndroid Build Coastguard Worker ] 1876*da0073e9SAndroid Build Coastguard Worker for input_size in input_sizes: 1877*da0073e9SAndroid Build Coastguard Worker a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9) 1878*da0073e9SAndroid Build Coastguard Worker # Run tests on transposed input if it has at least 2 dims 1879*da0073e9SAndroid Build Coastguard Worker for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]: 1880*da0073e9SAndroid Build Coastguard Worker a_n = a.cpu().numpy() 1881*da0073e9SAndroid Build Coastguard Worker for dim in range(-a.dim(), a.dim()): 1882*da0073e9SAndroid Build Coastguard Worker for indices in indices_args: 
1883*da0073e9SAndroid Build Coastguard Worker result_1 = torch.tensor_split(a, indices, dim) 1884*da0073e9SAndroid Build Coastguard Worker result_2 = torch.tensor_split( 1885*da0073e9SAndroid Build Coastguard Worker a, torch.tensor(indices, dtype=torch.int64), dim 1886*da0073e9SAndroid Build Coastguard Worker ) 1887*da0073e9SAndroid Build Coastguard Worker 1888*da0073e9SAndroid Build Coastguard Worker msg = f"input_size {input_size}, indices {indices}, dim {dim}" 1889*da0073e9SAndroid Build Coastguard Worker for r1, r2 in zip(result_1, result_2): 1890*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1.device, torch.device(device), msg=msg) 1891*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1.dtype, dtype, msg=msg) 1892*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r2.device, torch.device(device), msg=msg) 1893*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r2.dtype, dtype, msg=msg) 1894*da0073e9SAndroid Build Coastguard Worker 1895*da0073e9SAndroid Build Coastguard Worker result_n = np.array_split(a_n, indices, dim) 1896*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_n, result_1, msg=msg) 1897*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_n, result_2, msg=msg) 1898*da0073e9SAndroid Build Coastguard Worker 1899*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1900*da0073e9SAndroid Build Coastguard Worker def test_tensor_split_errors(self, device): 1901*da0073e9SAndroid Build Coastguard Worker S = 10 1902*da0073e9SAndroid Build Coastguard Worker test_cases = [ 1903*da0073e9SAndroid Build Coastguard Worker # input size, sections or indices, dim, error type, error message, numpy error type 1904*da0073e9SAndroid Build Coastguard Worker [(S,), 10, 1, IndexError, r"Dimension out of range", IndexError], 1905*da0073e9SAndroid Build Coastguard Worker [ 1906*da0073e9SAndroid Build Coastguard Worker (), 1907*da0073e9SAndroid Build Coastguard Worker 10, 1908*da0073e9SAndroid Build 
Coastguard Worker 0, 1909*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1910*da0073e9SAndroid Build Coastguard Worker r"tensor_split expected at least a 1-dimensional tensor, " 1911*da0073e9SAndroid Build Coastguard Worker + "but got a tensor with 0 dims", 1912*da0073e9SAndroid Build Coastguard Worker IndexError, 1913*da0073e9SAndroid Build Coastguard Worker ], 1914*da0073e9SAndroid Build Coastguard Worker [(S,), (10,), 1, IndexError, r"Dimension out of range", IndexError], 1915*da0073e9SAndroid Build Coastguard Worker [ 1916*da0073e9SAndroid Build Coastguard Worker (), 1917*da0073e9SAndroid Build Coastguard Worker (10,), 1918*da0073e9SAndroid Build Coastguard Worker 0, 1919*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1920*da0073e9SAndroid Build Coastguard Worker r"tensor_split expected at least a 1-dimensional tensor, " 1921*da0073e9SAndroid Build Coastguard Worker + "but got a tensor with 0 dims", 1922*da0073e9SAndroid Build Coastguard Worker IndexError, 1923*da0073e9SAndroid Build Coastguard Worker ], 1924*da0073e9SAndroid Build Coastguard Worker [ 1925*da0073e9SAndroid Build Coastguard Worker (S,), 1926*da0073e9SAndroid Build Coastguard Worker 0, 1927*da0073e9SAndroid Build Coastguard Worker 0, 1928*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1929*da0073e9SAndroid Build Coastguard Worker r"number of sections must be larger than 0, got 0", 1930*da0073e9SAndroid Build Coastguard Worker ValueError, 1931*da0073e9SAndroid Build Coastguard Worker ], 1932*da0073e9SAndroid Build Coastguard Worker [ 1933*da0073e9SAndroid Build Coastguard Worker (S,), 1934*da0073e9SAndroid Build Coastguard Worker -1, 1935*da0073e9SAndroid Build Coastguard Worker 0, 1936*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1937*da0073e9SAndroid Build Coastguard Worker r"number of sections must be larger than 0, got -1", 1938*da0073e9SAndroid Build Coastguard Worker ValueError, 1939*da0073e9SAndroid Build Coastguard Worker ], 1940*da0073e9SAndroid Build 
Coastguard Worker ] 1941*da0073e9SAndroid Build Coastguard Worker for input_size, sections_or_indices, dim, err, err_msg, numpy_err in test_cases: 1942*da0073e9SAndroid Build Coastguard Worker a = torch.randn(input_size, device=device) 1943*da0073e9SAndroid Build Coastguard Worker msg = f"input_size {input_size}, sections_or_indices {sections_or_indices}, dim {dim}" 1944*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(err, err_msg, msg=msg): 1945*da0073e9SAndroid Build Coastguard Worker torch.tensor_split(a, sections_or_indices, dim) 1946*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(err, err_msg, msg=msg): 1947*da0073e9SAndroid Build Coastguard Worker torch.tensor_split(a, torch.tensor(sections_or_indices), dim) 1948*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(numpy_err, msg=msg): 1949*da0073e9SAndroid Build Coastguard Worker np.array_split(a.cpu().numpy(), sections_or_indices, dim) 1950*da0073e9SAndroid Build Coastguard Worker 1951*da0073e9SAndroid Build Coastguard Worker # addtional tests for tensor_split with tensor_indices_or_sections 1952*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1953*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1954*da0073e9SAndroid Build Coastguard Worker r"tensor_split expected tensor_indices_or_sections to have dtype of long, but got Float", 1955*da0073e9SAndroid Build Coastguard Worker ): 1956*da0073e9SAndroid Build Coastguard Worker torch.tensor_split(a, torch.tensor(1.1), dim) 1957*da0073e9SAndroid Build Coastguard Worker 1958*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1959*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1960*da0073e9SAndroid Build Coastguard Worker r"tensor_split expected tensor_indices_or_sections to be a" 1961*da0073e9SAndroid Build Coastguard Worker + " zero-dimensional or one-dimensional tensor, but got a tensor with 2 dims", 1962*da0073e9SAndroid Build Coastguard Worker ): 
1963*da0073e9SAndroid Build Coastguard Worker torch.tensor_split(torch.rand(S, device=device), torch.tensor(((1,),)), 0) 1964*da0073e9SAndroid Build Coastguard Worker 1965*da0073e9SAndroid Build Coastguard Worker def test_resize_all_dtypes_and_devices(self, device): 1966*da0073e9SAndroid Build Coastguard Worker shape = (2, 2) 1967*da0073e9SAndroid Build Coastguard Worker for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): 1968*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) 1969*da0073e9SAndroid Build Coastguard Worker x.resize_(shape) 1970*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, x.shape) 1971*da0073e9SAndroid Build Coastguard Worker 1972*da0073e9SAndroid Build Coastguard Worker def test_resize_as_all_dtypes_and_devices(self, device): 1973*da0073e9SAndroid Build Coastguard Worker for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): 1974*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) 1975*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) 1976*da0073e9SAndroid Build Coastguard Worker x.resize_as_(y) 1977*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.shape, x.shape) 1978*da0073e9SAndroid Build Coastguard Worker 1979*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1980*da0073e9SAndroid Build Coastguard Worker def test_resize_overflow(self, device): 1981*da0073e9SAndroid Build Coastguard Worker x = torch.empty((), dtype=torch.float64) 1982*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1983*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Storage size calculation overflowed" 1984*da0073e9SAndroid Build Coastguard Worker ): 1985*da0073e9SAndroid Build Coastguard Worker x.resize_([2, 4, 2**29, 2**29]) 1986*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, "overflow"): 1987*da0073e9SAndroid Build Coastguard Worker x.resize_([8, 8, 2**29, 2**29]) 1988*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Stride calculation overflowed"): 1989*da0073e9SAndroid Build Coastguard Worker x.resize_([0, 4, 2305843009213693952]) 1990*da0073e9SAndroid Build Coastguard Worker 1991*da0073e9SAndroid Build Coastguard Worker def test_view_all_dtypes_and_devices(self, device): 1992*da0073e9SAndroid Build Coastguard Worker for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): 1993*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) 1994*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.view(6).shape, [6]) 1995*da0073e9SAndroid Build Coastguard Worker 1996*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("conj bit not implemented in TensorVariable yet") 1997*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1998*da0073e9SAndroid Build Coastguard Worker def test_conj_neg_view_numpy_error(self, device): 1999*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 2000*da0073e9SAndroid Build Coastguard Worker RuntimeError, 2001*da0073e9SAndroid Build Coastguard Worker "has conjugate bit set", 2002*da0073e9SAndroid Build Coastguard Worker lambda: torch.tensor([1 + 2j]).conj().numpy(), 2003*da0073e9SAndroid Build Coastguard Worker ) 2004*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 2005*da0073e9SAndroid Build Coastguard Worker RuntimeError, 2006*da0073e9SAndroid Build Coastguard Worker "has negative bit set", 2007*da0073e9SAndroid Build Coastguard Worker lambda: torch.tensor([1 + 2j]).conj().imag.numpy(), 2008*da0073e9SAndroid Build Coastguard Worker ) 2009*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 2010*da0073e9SAndroid Build Coastguard Worker RuntimeError, 2011*da0073e9SAndroid Build Coastguard Worker "not supported for 
conjugate view tensors", 2012*da0073e9SAndroid Build Coastguard Worker lambda: torch.tensor([1 + 2j]).conj().view(torch.float64), 2013*da0073e9SAndroid Build Coastguard Worker ) 2014*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 2015*da0073e9SAndroid Build Coastguard Worker RuntimeError, 2016*da0073e9SAndroid Build Coastguard Worker "not supported for tensors with negative bit set", 2017*da0073e9SAndroid Build Coastguard Worker lambda: torch.tensor([1 + 2j]).conj().imag.view(torch.int32), 2018*da0073e9SAndroid Build Coastguard Worker ) 2019*da0073e9SAndroid Build Coastguard Worker 2020*da0073e9SAndroid Build Coastguard Worker @onlyCPU 2021*da0073e9SAndroid Build Coastguard Worker def test_crow_col_indices(self, device): 2022*da0073e9SAndroid Build Coastguard Worker crow_indices = (0, 1, 2) 2023*da0073e9SAndroid Build Coastguard Worker col_indices = (1, 0) 2024*da0073e9SAndroid Build Coastguard Worker values = (1, 2) 2025*da0073e9SAndroid Build Coastguard Worker t = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2)) 2026*da0073e9SAndroid Build Coastguard Worker # This is the test. If crow_indices is not a view op it'll 2027*da0073e9SAndroid Build Coastguard Worker # trigger an internal assert due to use count greater than 1 2028*da0073e9SAndroid Build Coastguard Worker # in debug build. 2029*da0073e9SAndroid Build Coastguard Worker t.crow_indices() 2030*da0073e9SAndroid Build Coastguard Worker t.col_indices() 2031*da0073e9SAndroid Build Coastguard Worker 2032*da0073e9SAndroid Build Coastguard Worker 2033*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestViewOps, globals(), include_lazy=True) 2034*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestOldViewOps, globals()) 2035*da0073e9SAndroid Build Coastguard Worker 2036*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 2037*da0073e9SAndroid Build Coastguard Worker run_tests() 2038