1# Owner(s): ["module: named tensor"] 2 3import unittest 4from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY 5from torch.testing._internal.common_utils import skipIfTorchDynamo 6from torch.testing._internal.common_cuda import TEST_CUDA 7from torch.testing._internal.common_device_type import get_all_device_types 8from collections import namedtuple, OrderedDict 9import itertools 10import functools 11import torch 12from torch import Tensor 13import torch.nn.functional as F 14from multiprocessing.reduction import ForkingPickler 15import pickle 16import io 17import sys 18import warnings 19 20 21def pass_name_to_python_arg_parser(name): 22 x = torch.empty(2, names=(name,)) 23 24 25def flatten(lst): 26 return [item for sublist in lst for item in sublist] 27 28 29Function = namedtuple('TestCase', ['name', 'lambd']) 30 31 32def parse_compressed_namedshape(string): 33 # This is a metalanguage for describing a shape of a tensor compactly. 34 # 'N:3,C:2' -> size = [3, 2], names: ['N', 'C'] 35 # 'None:3,None:2' -> size = [3, 2], names: ['None', 'None'] 36 # '3,2' -> size = [3, 2], names=None passed to ctor. 37 def parse_name(maybe_name): 38 maybe_name = maybe_name.strip() 39 if maybe_name == 'None': 40 return None 41 return maybe_name 42 43 string = string.strip() 44 45 # '' -> size: [], names:None 46 if len(string) == 0: 47 return None, [] 48 49 # '3, 2' -> size = [3, 2], None names. 50 if ':' not in string: 51 return None, [int(size) for size in string.split(',')] 52 53 dims = string.split(',') 54 tuples = [dim.split(':') for dim in dims] 55 return zip(*[(parse_name(name), int(size)) for name, size in tuples]) 56 57 58def create(namedshape, factory=torch.randn): 59 # namedshape: str 60 names, shape = parse_compressed_namedshape(namedshape) 61 return factory(shape, names=names) 62 63 64def out_fn(operator): 65 @functools.wraps(operator) 66 def fn(*inputs): 67 return operator(*inputs[1:], out=inputs[0]) 68 return fn 69 70 71class TestNamedTensor(TestCase): 72 def test_aaa_must_run_first_check_experimental_warning(self): 73 # TODO(rzou): It would be nice for this to be a "real" python warning. 74 # Right now this error message only prints once and doesn't respect 75 # warnings.simplefilter behavior (where python users can control whether 76 # or not to display warnings once, all the time, or never). 77 with warnings.catch_warnings(record=True) as warns: 78 x = torch.randn(3, 3, names=('N', 'C')) 79 self.assertEqual(len(warns), 1) 80 self.assertTrue(str(warns[0].message).startswith( 81 'Named tensors and all their associated APIs are an experimental feature')) 82 83 def test_trivial(self): 84 pass 85 86 def _test_name_inference(self, op, args=(), expected_names=(), device='cpu', 87 maybe_raises_regex=None): 88 casted_args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg 89 for arg in args] 90 if maybe_raises_regex is not None: 91 with self.assertRaisesRegex(RuntimeError, maybe_raises_regex): 92 result = op(*args) 93 return 94 result = op(*args) 95 self.assertEqual(result.names, expected_names, 96 msg=f'Name inference for {op.__name__} on device {device} failed') 97 98 # TODO(rzou): Some form of this check should be added to self.assertEqual. 99 # Right now I don't know what it should look like. 100 def assertTensorDataAndNamesEqual(self, x, y): 101 self.assertEqual(x.names, y.names) 102 unnamed_x = x.rename(None) 103 unnamed_y = y.rename(None) 104 self.assertEqual(unnamed_x, unnamed_y) 105 106 def _test_factory(self, factory, device): 107 x = factory([], device=device) 108 self.assertEqual(x.names, ()) 109 110 x = factory(1, 2, 3, device=device) 111 self.assertEqual(x.names, (None, None, None)) 112 113 x = factory(1, 2, 3, names=None, device=device) 114 self.assertEqual(x.names, (None, None, None)) 115 116 x = factory(1, 2, 3, names=('N', 'T', 'D'), device=device) 117 self.assertEqual(x.names, ('N', 'T', 'D')) 118 119 x = factory(1, 2, 3, names=('N', None, 'D'), device=device) 120 self.assertEqual(x.names, ('N', None, 'D')) 121 122 x = factory(1, 2, 3, names=('_1', 'batch9', 'BATCH_5'), device=device) 123 self.assertEqual(x.names, ('_1', 'batch9', 'BATCH_5')) 124 125 with self.assertRaisesRegex(RuntimeError, 126 'a valid identifier contains only'): 127 x = factory(2, names=('1',), device=device) 128 129 with self.assertRaisesRegex(RuntimeError, 130 'a valid identifier contains only'): 131 x = factory(2, names=('?',), device=device) 132 133 with self.assertRaisesRegex(RuntimeError, 'Number of names'): 134 x = factory(2, 1, names=('N',), device=device) 135 136 with self.assertRaisesRegex(TypeError, 'invalid combination of arguments'): 137 x = factory(2, 1, names='N', device=device) 138 139 with self.assertRaisesRegex(RuntimeError, 'construct a tensor with duplicate names'): 140 x = factory(2, 1, 1, names=('N', 'C', 'N'), device=device) 141 142 names64 = ['A' * i for i in range(1, 65)] 143 x = factory([1] * 64, names=names64, device=device) 144 self.assertEqual(x.names, names64) 145 146 with self.assertRaisesRegex( 147 RuntimeError, 148 'only support up to 64 dims'): 149 names65 = ['A' * i for i in range(1, 66)] 150 x = factory([1] * 65, names=names64, device=device) 151 152 @skipIfTorchDynamo("not a bug: Dynamo causes the refcounts to be different") 153 def test_none_names_refcount(self, N=10): 154 def scope(): 155 unnamed = torch.empty(2, 3) 156 unnamed.names # materialize [None, None] 157 158 prev_none_refcnt = sys.getrefcount(None) 159 # Ran it N times to reduce flakiness 160 [scope() for i in range(N)] 161 after_none_refcnt = sys.getrefcount(None) 162 self.assertTrue(after_none_refcnt - prev_none_refcnt < N / 2, 163 msg='Using tensor.names should not change ' 164 'the refcount of Py_None') 165 166 def test_has_names(self): 167 unnamed = torch.empty(2, 3) 168 none_named = torch.empty(2, 3, names=(None, None)) 169 partially_named = torch.empty(2, 3, names=('N', None)) 170 fully_named = torch.empty(2, 3, names=('N', 'C')) 171 172 self.assertFalse(unnamed.has_names()) 173 self.assertFalse(none_named.has_names()) 174 self.assertTrue(partially_named.has_names()) 175 self.assertTrue(fully_named.has_names()) 176 177 def test_py3_ellipsis(self): 178 tensor = torch.randn(2, 3, 5, 7) 179 output = tensor.refine_names('N', ..., 'C') 180 self.assertEqual(output.names, ['N', None, None, 'C']) 181 182 def test_refine_names(self): 183 # Unnamed tensor -> Unnamed tensor 184 self._test_name_inference(Tensor.refine_names, 185 [create('None:1,None:2,None:3'), 'N', 'C', 'H'], 186 ['N', 'C', 'H']) 187 188 # Named tensor -> Named tensor 189 self._test_name_inference(Tensor.refine_names, 190 [create('N:1,C:2,H:3'), 'N', 'C', 'H'], 191 ['N', 'C', 'H']) 192 193 # Partially named tensor -> named tensor 194 self._test_name_inference(Tensor.refine_names, 195 [create('None:1,C:2,None:3'), None, 'C', 'H'], 196 [None, 'C', 'H']) 197 198 # Too few names 199 self._test_name_inference(Tensor.refine_names, 200 [create('None:2,None:3'), 'N', 'C', 'H'], 201 maybe_raises_regex="different number of dims") 202 203 # Cannot change Tensor[D] to Tensor[N] 204 self._test_name_inference(Tensor.refine_names, 205 [create('D:3'), 'N'], 206 maybe_raises_regex="is different from") 207 208 # Cannot change Tensor[D] to Tensor[None] 209 self._test_name_inference(Tensor.refine_names, 210 [create('D:3'), None], 211 maybe_raises_regex="'D' is more specific than None") 212 213 # globbing behavior exists 214 self._test_name_inference(Tensor.refine_names, 215 [create('None:1,None:1,None:2,None:3'), '...', 'C', 'H'], 216 [None, None, 'C', 'H']) 217 218 def test_detach(self): 219 names = ['N'] 220 self._test_name_inference( 221 Tensor.detach_, 222 [torch.randn(3, requires_grad=True, names=names)], 223 names) 224 self._test_name_inference( 225 Tensor.detach, 226 [torch.randn(3, requires_grad=True, names=names)], 227 names) 228 229 def test_index_fill(self): 230 for device in get_all_device_types(): 231 expected_names = ('N', 'C') 232 x = torch.randn(3, 5, device=device, names=expected_names) 233 234 output = x.index_fill_('C', torch.tensor([0, 1], device=device), 5) 235 self.assertEqual(output.names, expected_names) 236 237 output = x.index_fill_('C', torch.tensor([0, 1], device=device), torch.tensor(4.)) 238 self.assertEqual(output.names, expected_names) 239 240 output = x.index_fill('C', torch.tensor([0, 1], device=device), 5) 241 self.assertEqual(output.names, expected_names) 242 243 output = x.index_fill('C', torch.tensor([0, 1], device=device), torch.tensor(4.)) 244 self.assertEqual(output.names, expected_names) 245 246 def test_equal(self): 247 for device in get_all_device_types(): 248 tensor = torch.randn(2, 3, device=device) 249 other = tensor.clone() 250 251 self.assertTrue(torch.equal(tensor.rename('N', 'C'), other.rename('N', 'C'))) 252 self.assertFalse(torch.equal(tensor.rename('M', 'C'), other.rename('N', 'C'))) 253 self.assertFalse(torch.equal(tensor.rename(None, 'C'), other.rename('N', 'C'))) 254 255 def test_squeeze(self): 256 x = create('N:3,C:1,H:1,W:1') 257 output = x.squeeze('C') 258 self.assertEqual(output.names, ['N', 'H', 'W']) 259 260 output = x.squeeze() 261 self.assertEqual(output.names, ['N']) 262 263 def test_repr(self): 264 named_tensor = torch.zeros(2, 3).rename_('N', 'C') 265 expected = "tensor([[0., 0., 0.],\n [0., 0., 0.]], names=('N', 'C'))" 266 self.assertEqual(repr(named_tensor), expected) 267 268 unnamed_tensor = torch.zeros(2, 3) 269 expected = "tensor([[0., 0., 0.],\n [0., 0., 0.]])" 270 self.assertEqual(repr(unnamed_tensor), expected) 271 272 none_named_tensor = torch.zeros(2, 3).rename_(None, None) 273 self.assertEqual(repr(none_named_tensor), expected) 274 275 def test_diagonal(self): 276 named_tensor = torch.zeros(2, 3, 5, 7, names=list('ABCD')) 277 self.assertEqual(named_tensor.diagonal().names, ['C', 'D', None]) 278 self.assertEqual(named_tensor.diagonal(1, 3).names, ['A', 'C', None]) 279 280 self.assertEqual(named_tensor.diagonal(outdim='E', dim1='B', dim2='D').names, 281 ['A', 'C', 'E']) 282 283 def test_max_pooling(self): 284 def check_tuple_return(op, inputs, expected_names): 285 values, indices = op(*inputs) 286 self.assertEqual(values.names, expected_names) 287 self.assertEqual(indices.names, expected_names) 288 289 for device in get_all_device_types(): 290 291 named_tensor_1d = torch.zeros(2, 3, 5, device=device, names=list('ABC')) 292 named_tensor_2d = torch.zeros(2, 3, 5, 7, device=device, names=list('ABCD')) 293 named_tensor_3d = torch.zeros(2, 3, 5, 7, 9, device=device, names=list('ABCDE')) 294 295 self.assertEqual(F.max_pool1d(named_tensor_1d, 2).names, named_tensor_1d.names) 296 self.assertEqual(F.max_pool2d(named_tensor_2d, [2, 2]).names, named_tensor_2d.names) 297 self.assertEqual(F.max_pool3d(named_tensor_3d, [2, 2, 2]).names, named_tensor_3d.names) 298 299 check_tuple_return(F.max_pool1d_with_indices, [named_tensor_1d, 2], named_tensor_1d.names) 300 check_tuple_return(F.max_pool2d_with_indices, [named_tensor_2d, [2, 2]], named_tensor_2d.names) 301 check_tuple_return(F.max_pool3d_with_indices, [named_tensor_3d, [2, 2, 2]], named_tensor_3d.names) 302 303 def test_max_pooling_without_names_does_not_warn(self): 304 for device in get_all_device_types(): 305 tensor_2d = torch.zeros(2, 3, 5, 7, device=device, requires_grad=True) 306 with warnings.catch_warnings(record=True) as warns: 307 warnings.simplefilter("always") 308 result = F.max_pool2d(tensor_2d, [2, 2]) 309 result.sum().backward() 310 self.assertEqual(len(warns), 0) 311 312 def test_no_save_support(self): 313 named_tensor = torch.zeros(2, 3, names=('N', 'C')) 314 buf = io.BytesIO() 315 with self.assertRaisesRegex(RuntimeError, "NYI"): 316 torch.save(named_tensor, buf) 317 318 def test_no_pickle_support(self): 319 named_tensor = torch.zeros(2, 3, names=('N', 'C')) 320 with self.assertRaisesRegex(RuntimeError, "NYI"): 321 serialized = pickle.dumps(named_tensor) 322 323 def test_no_multiprocessing_support(self): 324 named_tensor = torch.zeros(2, 3, names=('N', 'C')) 325 buf = io.BytesIO() 326 with self.assertRaisesRegex(RuntimeError, "NYI"): 327 ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(named_tensor) 328 329 def test_big_tensor_repr_has_names(self): 330 def check_repr(named_tensor): 331 unnamed_tensor = named_tensor.rename(None) 332 names_tag = f'names={named_tensor.names}' 333 self.assertIn(names_tag, repr(named_tensor)) 334 335 check_repr(torch.randn(128, 3, 64, 64, names=('N', 'C', 'H', 'W'))) 336 337 def test_noncontig_contiguous(self): 338 # This type of contiguous is special-cased and therefore needs its own test 339 for device in get_all_device_types(): 340 x = torch.randn(2, 3, device=device).t().rename_('N', 'C') 341 self.assertEqual(x.contiguous().names, ('N', 'C')) 342 343 def test_copy_transpose(self): 344 # This type of copy is special-cased and therefore needs its own test 345 def _test(self_names, other_names, expected_names): 346 x = torch.empty(2, 5, names=self_names) 347 y = torch.empty(5, 2).t().rename_(*other_names) 348 x.copy_(y) 349 self.assertEqual(x.names, expected_names) 350 351 _test(('N', 'C'), ('N', 'C'), ('N', 'C')) 352 _test(None, ('N', 'C'), ('N', 'C')) 353 354 def test_rename_(self): 355 tensor = torch.empty(1, 1, names=('N', 'C')) 356 self.assertEqual(tensor.rename_(None).names, (None, None)) 357 self.assertEqual(tensor.rename_('H', 'W').names, ('H', 'W')) 358 with self.assertRaisesRegex(RuntimeError, 'Number of names'): 359 tensor.rename_('N', 'C', 'W') 360 with self.assertRaisesRegex(RuntimeError, 'duplicate names'): 361 tensor.rename_('N', 'N') 362 363 def test_rename(self): 364 tensor = torch.empty(1, 1, names=('N', 'C')) 365 366 self.assertEqual(tensor.rename(None).names, (None, None)) 367 self.assertEqual(tensor.rename('H', 'W').names, ('H', 'W')) 368 369 # Check that we didn't modify tensor.names 370 self.assertEqual(tensor.names, ('N', 'C')) 371 372 with self.assertRaisesRegex(RuntimeError, 'Number of names'): 373 tensor.rename('N', 'C', 'W') 374 with self.assertRaisesRegex(RuntimeError, 'duplicate names'): 375 tensor.rename('N', 'N') 376 377 with self.assertRaisesRegex(RuntimeError, 'either positional args or keyword args'): 378 tensor.rename(None, N='batch') 379 380 # rename returns a view on the tensor 381 self.assertEqual(tensor.rename('H', 'W').data_ptr(), tensor.data_ptr()) 382 self.assertEqual(tensor.rename(None).data_ptr(), tensor.data_ptr()) 383 384 def test_rename_globber(self): 385 scalar = torch.randn([]) 386 unnamed_tensor = torch.empty(1, 1, 1, 1) 387 named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W')) 388 389 self.assertEqual(scalar.rename(None).names, []) 390 self.assertEqual(scalar.rename('...').names, []) 391 392 # Check that it works with unnamed tensors 393 self.assertEqual(unnamed_tensor.rename('...').names, unnamed_tensor.names) 394 self.assertEqual(unnamed_tensor.rename('...', 'H', 'W').names, 395 [None, None, 'H', 'W']) 396 self.assertEqual(unnamed_tensor.rename('N', '...', 'W').names, 397 ['N', None, None, 'W']) 398 self.assertEqual(unnamed_tensor.rename('N', 'C', '...').names, 399 ['N', 'C', None, None]) 400 401 # Check that it works with named tensors 402 self.assertEqual(named_tensor.rename('...').names, named_tensor.names) 403 self.assertEqual(named_tensor.rename('...', 'width').names, 404 ['N', 'C', 'H', 'width']) 405 self.assertEqual(named_tensor.rename('batch', 'channels', '...', 'width').names, 406 ['batch', 'channels', 'H', 'width']) 407 self.assertEqual(named_tensor.rename('batch', '...').names, 408 ['batch', 'C', 'H', 'W']) 409 410 # Test empty glob 411 self.assertEqual(unnamed_tensor.rename('...', None, None, None, None).names, 412 [None, None, None, None]) 413 self.assertEqual(named_tensor.rename('N', 'C', 'H', '...', 'W').names, 414 ['N', 'C', 'H', 'W']) 415 416 # Multiple globs throw 417 with self.assertRaisesRegex(RuntimeError, 'More than one '): 418 named_tensor.rename('...', 'channels', '...') 419 420 def test_rename_rename_map(self): 421 scalar = torch.randn([]) 422 unnamed_tensor = torch.empty(1, 1, 1, 1) 423 named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W')) 424 425 with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"): 426 scalar.rename(N='batch') 427 with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"): 428 unnamed_tensor.rename(N='batch') 429 with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"): 430 named_tensor.rename(B='batch') 431 with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"): 432 named_tensor.rename(H='height', B='batch') 433 434 self.assertEqual(named_tensor.rename(N='batch').data_ptr(), 435 named_tensor.data_ptr()) 436 self.assertEqual(named_tensor.rename(N='batch').names, 437 ['batch', 'C', 'H', 'W']) 438 self.assertEqual(named_tensor.rename(N='batch', H='height').names, 439 ['batch', 'C', 'height', 'W']) 440 441 def test_set_names_property(self): 442 tensor = torch.empty(1, 1, names=('N', 'C')) 443 444 tensor.names = None 445 self.assertEqual(tensor.names, (None, None)) 446 447 tensor.names = ('N', 'W') 448 self.assertEqual(tensor.names, ('N', 'W')) 449 450 with self.assertRaisesRegex(RuntimeError, 'Number of names'): 451 tensor.names = ['N', 'C', 'W'] 452 with self.assertRaisesRegex(RuntimeError, 'duplicate names'): 453 tensor.names = ['N', 'N'] 454 455 def test_factory_edge_cases(self): 456 for device in get_all_device_types(): 457 self._test_factory(torch.empty, device) 458 459 def test_factory_coverage(self): 460 def _test(factory, device): 461 names = ('N', 'T', 'D') 462 463 torch.manual_seed(0) 464 result = factory(1, 2, 3, names=names, device=device) 465 466 torch.manual_seed(0) 467 expected = factory(1, 2, 3, device=device).rename_(*names) 468 469 self.assertTensorDataAndNamesEqual(result, expected) 470 471 supported = [ 472 torch.ones, 473 torch.rand, 474 torch.randn, 475 torch.zeros, 476 ] 477 478 for op, device in itertools.product(supported, get_all_device_types()): 479 _test(op, device) 480 481 # Test torch.full 482 for device in get_all_device_types(): 483 names = ('N', 'T', 'D') 484 result = torch.full([1, 2, 3], 2., names=names, device=device) 485 expected = torch.full([1, 2, 3], 2., device=device).rename_(*names) 486 self.assertTensorDataAndNamesEqual(result, expected) 487 488 def test_tensor_from_lists(self): 489 names = ('N', 'C') 490 tensor = torch.tensor([[1]], names=names) 491 self.assertEqual(tensor.names, names) 492 493 names = ('N',) 494 tensor = torch.tensor([1], names=names) 495 self.assertEqual(tensor.names, names) 496 497 with self.assertRaisesRegex(RuntimeError, 'Number of names'): 498 names = ('N', 'C') 499 tensor = torch.tensor([1], names=names) 500 501 @unittest.skipIf(not TEST_NUMPY, "no numpy") 502 def test_tensor_from_numpy(self): 503 import numpy as np 504 arr = np.array([[1]]) 505 names = ('N', 'C') 506 tensor = torch.tensor([[1]], names=names) 507 self.assertEqual(tensor.names, names) 508 509 def test_tensor_from_tensor(self): 510 x = torch.randn(1, 1) 511 names = ('N', 'C') 512 tensor = torch.tensor(x, names=names) 513 self.assertEqual(tensor.names, names) 514 515 def test_tensor_from_named_tensor(self): 516 x = torch.randn(1, 1, names=('N', 'D')) 517 tensor = torch.tensor(x) 518 self.assertEqual(tensor.names, ('N', 'D')) 519 520 # there's no way to distinguish between names=None and not passing in names. 521 # If the user passes in names=None they are asking for trouble. 522 x = torch.randn(1, 1, names=('N', 'D')) 523 tensor = torch.tensor(x, names=None) 524 self.assertEqual(tensor.names, ('N', 'D')) 525 526 x = torch.randn(1, 1, names=('N', 'D')) 527 with self.assertRaisesRegex(RuntimeError, "Name mismatch"): 528 tensor = torch.tensor(x, names=('N', 'C')) 529 530 def test_size(self): 531 t = torch.empty(2, 3, 5, names=('N', None, 'C')) 532 self.assertEqual(t.size('N'), 2) 533 self.assertEqual(t.size('C'), 5) 534 with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '): 535 t.size('channels') 536 with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '): 537 torch.empty(2, 3, 4).size('N') 538 539 def test_stride(self): 540 t = torch.empty(2, 3, 5, names=('N', None, 'C')) 541 self.assertEqual(t.stride('N'), 3 * 5) 542 self.assertEqual(t.stride('C'), 1) 543 with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '): 544 t.stride('channels') 545 with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '): 546 torch.empty(2, 3, 4).stride('N') 547 548 def test_transpose_variants(self): 549 t = torch.randn(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) 550 self.assertEqual(t.transpose('N', 'C').names, ['C', 'N', 'H', 'W']) 551 self.assertEqual(t.transpose(1, 3).names, ['N', 'W', 'H', 'C']) 552 553 t = torch.randn(2, 3, names=('N', 'C')) 554 self.assertEqual(t.t().names, ['C', 'N']) 555 556 def test_resize(self): 557 for device in get_all_device_types(): 558 named = torch.randn(2, names=('N',), device=device) 559 named.resize_([2]) 560 self.assertEqual(named.names, ['N']) 561 562 with self.assertRaisesRegex(RuntimeError, "Cannot resize named tensor"): 563 named.resize_([3]) 564 565 other_named = torch.randn(2, names=('N',), device=device) 566 named.resize_as_(other_named) 567 self.assertEqual(other_named.names, ['N']) 568 569 unnamed = torch.randn(2, device=device) 570 with self.assertRaisesRegex( 571 RuntimeError, r'names .* are not the same as the computed output names'): 572 named.resize_as_(unnamed) 573 574 unnamed = torch.randn(1, device=device) 575 unnamed.resize_as_(named) 576 self.assertEqual(unnamed.names, ['N']) 577 578 def test_cdist(self): 579 for device in get_all_device_types(): 580 tensor = torch.randn(3, 1, 2, 7, names=('M', 'N', 'first_group', 'features'), 581 device=device) 582 other = torch.randn(5, 11, 7, names=('N', 'second_group', 'features'), 583 device=device) 584 result = torch.cdist(tensor, other) 585 self.assertEqual(result.names, ['M', 'N', 'first_group', 'second_group']) 586 587 def test_info_smoke(self): 588 # Smoke test for info functions / methods / attributes on named tensors. 589 tensor = torch.empty(1, 1, names=('N', 'D')) 590 591 tensor.device 592 tensor.dtype 593 tensor.get_device() 594 tensor.is_complex() 595 tensor.is_floating_point() 596 tensor.is_nonzero() 597 torch.is_same_size(tensor, tensor) 598 torch.is_signed(tensor) 599 tensor.layout 600 tensor.numel() 601 tensor.dim() 602 tensor.element_size() 603 tensor.is_contiguous() 604 tensor.is_cuda 605 tensor.is_leaf 606 tensor.is_pinned() 607 tensor.is_shared() 608 tensor.is_sparse 609 tensor.ndimension() 610 tensor.nelement() 611 tensor.shape 612 tensor.size() 613 tensor.size(1) 614 tensor.storage() 615 tensor.storage_offset() 616 tensor.storage_type() 617 tensor.stride() 618 tensor.stride(1) 619 tensor.data 620 tensor.data_ptr() 621 tensor.ndim 622 tensor.item() 623 tensor.type() 624 tensor.is_shared() 625 tensor.is_signed() 626 627 def test_autograd_smoke(self): 628 x = torch.randn(3, 3, names=('N', 'D'), requires_grad=True) 629 630 y = x.clone() 631 y.retain_grad() 632 y.register_hook(lambda x: x) 633 634 y.sum().backward() 635 636 # autograd related attributes 637 tensor = torch.empty(1, 1, names=('N', 'D'), requires_grad=True) 638 tensor = tensor.relu() 639 tensor.output_nr 640 tensor.grad_fn 641 tensor.requires_grad 642 643 def test_split_fns_propagates_names(self): 644 fns = [ 645 lambda x: x.split(1, 0), 646 lambda x: x.split([1, 1], 1), 647 lambda x: x.chunk(2, 0), 648 ] 649 650 for device in get_all_device_types(): 651 orig_tensor = torch.empty(2, 2, names=('N', 'D'), device=device) 652 for fn in fns: 653 splits = fn(orig_tensor) 654 for split in splits: 655 self.assertEqual(split.names, orig_tensor.names) 656 657 def test_any_all(self): 658 for device in get_all_device_types(): 659 x = torch.zeros(3, dtype=torch.bool, device=device, names=('C',)) 660 self.assertEqual(x.any().names, []) 661 self.assertEqual(x.all().names, []) 662 663 def test_addcmul_addcdiv(self): 664 for device in get_all_device_types(): 665 names = ['N'] 666 a = torch.rand(3, device=device, names=names) 667 b = torch.rand(3, device=device, names=names) 668 # avoid division by 0 669 c = torch.rand(3, device=device, names=names).clamp_min_(0.1) 670 out = torch.randn(3, device=device, names=names) 671 672 self.assertEqual(torch.addcmul(a, b, c).names, names) 673 self.assertEqual(torch.addcmul(a, b, c, out=out).names, names) 674 self.assertEqual(a.addcmul_(b, c).names, names) 675 676 self.assertEqual(torch.addcdiv(a, b, c).names, names) 677 self.assertEqual(torch.addcdiv(a, b, c, out=out).names, names) 678 self.assertEqual(a.addcdiv_(b, c).names, names) 679 680 def test_binary_ops(self): 681 def test_basic(op): 682 a = torch.empty(2, 3, names=('N', 'C')) 683 b = torch.empty(3, 2, names=('C', 'N')) 684 c = torch.empty(3, names=('C',)) 685 d = torch.empty(5, names=('W',)) 686 687 self.assertEqual(op(a, a).names, ('N', 'C')) 688 self.assertEqual(op(a, c).names, ('N', 'C')) 689 # TODO: dynamo will throw a slightly different 690 # error message because it's adding fake tensors 691 # `must match the size of` portion is the dynamo error 692 with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"): 693 op(a, d) 694 with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"): 695 op(a, b) 696 697 def test_wildcard(op): 698 a = torch.empty(2, 3, names=('N', 'C')) 699 c = torch.empty(2, 3, names=(None, 'C')) 700 self.assertEqual(op(a, c).names, ('N', 'C')) 701 702 b = torch.empty(2, 3) 703 self.assertEqual(op(a, b).names, ('N', 'C')) 704 705 d = torch.empty(2, 3, names=('C', None)) 706 with self.assertRaisesRegex(RuntimeError, "Misaligned"): 707 op(d, c) 708 709 def test_mixed_unnamed_named(op, is_inplace): 710 named2 = torch.randn(1, 1, names=('N', 'C')) 711 unnamed1 = torch.randn(1) 712 unnamed2 = torch.randn(1, 1) 713 unnamed3 = torch.randn(1, 1, 1) 714 715 def compute_expected_names(tensor, other): 716 assert tensor.has_names() ^ other.has_names() 717 named = tensor if tensor.has_names() else other 718 unnamed = other if tensor.has_names() else tensor 719 unnamed_dim = unnamed.dim() 720 if unnamed_dim > named.dim(): 721 return [None] * (unnamed_dim - named.dim()) + list(named.names) 722 else: 723 return named.names 724 725 inputs = itertools.chain( 726 itertools.product([named2], [unnamed1, unnamed2, unnamed3]), 727 itertools.product([unnamed1, unnamed2, unnamed3], [named2]), 728 ) 729 if is_inplace: 730 # In-place ops have the constraint that they must not change shape. 731 inputs = [(a, b) for (a, b) in inputs if a.dim() >= b.dim()] 732 733 for tensor, other in inputs: 734 expected_names = compute_expected_names(tensor, other) 735 self.assertEqual(op(tensor, other).names, expected_names) 736 737 def method(name, *args, **kwargs): 738 return [Function(name, lambda a, b: getattr(a, name)(b, *args, **kwargs))] 739 740 def function(name, *args, **kwargs): 741 return [Function(name, lambda a, b: getattr(torch, name)(a, b, *args, **kwargs))] 742 743 def out_function(name, *args, **kwargs): 744 out_fn = getattr(torch, name) 745 746 def fn(a, b): 747 result = torch.empty([0], dtype=a.dtype, device=a.device) 748 out_fn(a, b, *args, out=result, **kwargs) 749 return result 750 751 return [Function(name, fn)] 752 753 def fn_method_and_inplace(name, *args, **kwargs): 754 return ( 755 method(name, *args, **kwargs) + 756 method(name + '_', *args, **kwargs) + 757 out_function(name, *args, **kwargs) 758 ) 759 760 tests = [ 761 fn_method_and_inplace('add'), 762 fn_method_and_inplace('div'), 763 fn_method_and_inplace('mul'), 764 fn_method_and_inplace('sub'), 765 fn_method_and_inplace('pow'), 766 fn_method_and_inplace('atan2'), 767 method('copy_'), 768 function('floor_divide'), 769 function('true_divide'), 770 ] 771 tests = flatten(tests) 772 773 for name, op in tests: 774 test_basic(op) 775 test_wildcard(op) 776 test_mixed_unnamed_named(op, is_inplace=name.endswith('_')) 777 778 def test_logical_ops(self): 779 # Implemented via TensorIterator, so just check that each version 780 # (out-of-place, inplace, out=) propagates names. 781 def zeros(*args, **kwargs): 782 return torch.zeros(*args, dtype=torch.bool, **kwargs) 783 784 for op in ('logical_xor', 'logical_and', 'logical_or'): 785 self._test_name_inference( 786 getattr(torch, op), 787 (create('N:2,C:3', zeros), create('N:2,C:3', zeros)), 788 expected_names=['N', 'C']) 789 790 self._test_name_inference( 791 getattr(Tensor, op + '_'), 792 (create('N:2,C:3', zeros), create('N:2,C:3', zeros)), 793 expected_names=['N', 'C']) 794 795 self._test_name_inference( 796 lambda out, x, y: getattr(torch, op)(x, y, out=out), 797 (create('0', zeros), create('N:2,C:3', zeros), create('N:2,C:3', zeros)), 798 expected_names=['N', 'C']) 799 800 def test_pow_special(self): 801 # There are a few pow cases that don't go through TensorIterator. 802 # Test them here. 803 for device in get_all_device_types(): 804 named = torch.randn(2, 3, names=('N', 'C'), device=device) 805 unnamed = torch.randn([0], device=device) 806 807 result = torch.pow(named, 0, out=unnamed.clone()) 808 self.assertEqual(result.names, named.names) 809 810 result = torch.pow(named, 1, out=unnamed.clone()) 811 self.assertEqual(result.names, named.names) 812 813 result = torch.pow(1, named, out=unnamed.clone()) 814 self.assertEqual(result.names, named.names) 815 816 def test_out_fn_semantics(self): 817 out_fn = torch.abs 818 unnamed_tensor = torch.randn(3, 2) 819 none_named_tensor = torch.randn(3, 2, names=(None, None)) 820 named_tensor = torch.randn(3, 2, names=('N', 'C')) 821 partially_named_tensor = torch.randn(3, 2, names=('N', None)) 822 823 with self.assertRaisesRegex(RuntimeError, "Name mismatch"): 824 out_fn(partially_named_tensor, out=named_tensor) 825 with self.assertRaisesRegex(RuntimeError, "Name mismatch"): 826 out_fn(named_tensor, out=partially_named_tensor) 827 with self.assertRaisesRegex(RuntimeError, "Name mismatch"): 828 out_fn(none_named_tensor, out=named_tensor) 829 with self.assertRaisesRegex(RuntimeError, "Name mismatch"): 830 out_fn(unnamed_tensor, out=named_tensor) 831 832 output = torch.randn(3, 2) 833 out_fn(unnamed_tensor, out=output) 834 self.assertFalse(output.has_names()) 835 836 output = torch.randn(3, 2, names=(None, None)) 837 out_fn(named_tensor, out=output) 838 self.assertEqual(output.names, named_tensor.names) 839 840 output = torch.randn(3, 2) 841 out_fn(named_tensor, out=output) 842 self.assertEqual(output.names, named_tensor.names) 843 844 output = torch.randn(3, 2, names=(None, None)) 845 out_fn(unnamed_tensor, out=output) 846 self.assertFalse(output.has_names()) 847 848 def test_unary_propagate_names_fns(self): 849 def _test(testcase, names=('N', 'D'), device='cpu'): 850 sizes = [2] * len(names) 851 tensor = torch.empty(sizes, names=names, device=device) 852 try: 853 out = testcase.lambd(tensor) 854 except RuntimeError as err: 855 # Get a better error message by catching the error and asserting. 856 raise RuntimeError(f'{testcase.name}: {err}') from err 857 self.assertEqual(out.names, tensor.names, 858 msg=testcase.name) 859 860 def fn(name, *args, **kwargs): 861 return [Function(name, lambda t: getattr(torch, name)(t, *args, **kwargs))] 862 863 def method(name, *args, **kwargs): 864 return [Function(name, lambda t: getattr(t, name)(*args, **kwargs))] 865 866 def out_function(name, *args, **kwargs): 867 out_fn = getattr(torch, name) 868 869 def fn(tensor): 870 result = torch.empty([0], dtype=tensor.dtype, device=tensor.device) 871 out_fn(tensor, *args, out=result, **kwargs) 872 return result 873 874 return [Function(name + '_out', fn)] 875 876 def fn_method_and_inplace(name, *args, **kwargs): 877 return ( 878 method(name, *args, **kwargs) + 879 method(name + '_', *args, **kwargs) + 880 out_function(name, *args, **kwargs) 881 ) 882 883 # All of these operate on 2x2 tensors. 884 tests = [ 885 # unary pointwise 886 fn_method_and_inplace('abs'), 887 fn_method_and_inplace('acos'), 888 fn_method_and_inplace('asin'), 889 fn_method_and_inplace('atan'), 890 fn_method_and_inplace('ceil'), 891 fn_method_and_inplace('clamp', -1, 1), 892 fn_method_and_inplace('clamp_min', -2), 893 fn_method_and_inplace('clamp_max', 2), 894 method('cauchy_'), 895 method('clone'), 896 method('contiguous'), 897 fn_method_and_inplace('cos'), 898 fn_method_and_inplace('cosh'), 899 fn_method_and_inplace('digamma'), 900 fn_method_and_inplace('erf'), 901 fn_method_and_inplace('erfc'), 902 fn_method_and_inplace('erfinv'), 903 fn_method_and_inplace('exp'), 904 fn_method_and_inplace('expm1'), 905 method('exponential_'), 906 fn_method_and_inplace('floor'), 907 fn_method_and_inplace('frac'), 908 method('geometric_', p=0.5), 909 fn_method_and_inplace('lgamma'), 910 fn_method_and_inplace('log'), 911 fn_method_and_inplace('log10'), 912 fn_method_and_inplace('log1p'), 913 fn_method_and_inplace('log2'), 914 method('log_normal_'), 915 fn_method_and_inplace('neg'), 916 method('normal_'), 917 [Function('polygamma', lambda t: torch.polygamma(1, t))], 918 method('polygamma_', 1), 919 fn_method_and_inplace('reciprocal'), 920 method('random_', 0, 1), 921 method('random_', 1), 922 method('random_'), 923 method('relu_'), 924 method('requires_grad_'), 925 method('relu'), 926 fn_method_and_inplace('round'), 927 fn_method_and_inplace('rsqrt'), 928 fn_method_and_inplace('sigmoid'), 929 fn_method_and_inplace('sign'), 930 fn_method_and_inplace('sin'), 931 fn_method_and_inplace('sinh'), 932 fn_method_and_inplace('sqrt'), 933 fn_method_and_inplace('tan'), 934 fn_method_and_inplace('tanh'), 935 fn('threshold', 0, 1), 936 fn('threshold_', 0, 1), 937 out_function('threshold', 0, 1), 938 fn_method_and_inplace('trunc'), 939 method('uniform_'), 940 method('zero_'), 941 method('fill_', 1), 942 method('fill_', torch.tensor(3.14)), 943 944 # conversions 945 method('to', dtype=torch.long), 946 method('to', device='cpu'), 947 method('to', torch.empty([])), 948 method('bool'), 949 method('byte'), 950 method('char'), 951 method('cpu'), 952 method('double'), 953 method('float'), 954 method('long'), 955 method('half'), 956 method('int'), 957 method('short'), 958 method('type', dtype=torch.long), 959 960 # cumsum and cumprod 961 fn('cumsum', 0), 962 fn('cumsum', 'D'), 963 out_function('cumsum', 'D'), 964 fn('cumprod', 0), 965 fn('cumprod', 'D'), 966 out_function('cumprod', 'D'), 967 968 # views 969 method('narrow', 0, 0, 1), 970 971 # creation functions 972 fn('empty_like'), 973 fn('zeros_like'), 974 fn('ones_like'), 975 fn('full_like', 3.14), 976 fn('rand_like'), 977 fn('randn_like'), 978 979 # bernoulli variants 980 method('bernoulli_', 0.5), 981 method('bernoulli_', torch.tensor(0.5)), 982 983 method('softmax', dim=1), 984 method('softmax', dim='D'), 985 method('log_softmax', dim=1), 986 method('log_softmax', dim='D'), 987 988 [Function('F.dropout(inplace)', lambda t: F.dropout(t, p=0.5, inplace=True))], 989 [Function('F.dropout(outplace)', lambda t: F.dropout(t, p=0.5, inplace=False))], 990 ] 991 tests = flatten(tests) 992 993 for testcase, device in itertools.product(tests, get_all_device_types()): 994 _test(testcase, device=device) 995 996 def test_cummax_cummin(self): 997 def test_ops(op): 998 for device in get_all_device_types(): 999 names = ('N', 'D') 1000 tensor = torch.rand(2, 3, names=names) 1001 result = op(tensor, 0) 1002 self.assertEqual(result[0].names, names) 1003 self.assertEqual(result[1].names, names) 1004 test_ops(torch.cummax) 1005 test_ops(torch.cummin) 1006 1007 def test_logcumsumexp(self): 1008 for device in get_all_device_types(): 1009 names = ('N', 'D') 1010 tensor = torch.rand(2, 3, names=names) 1011 result = torch.logcumsumexp(tensor, 'D') 1012 self.assertEqual(result.names, names) 1013 1014 def test_bitwise_not(self): 1015 for device in get_all_device_types(): 1016 names = ('N', 'D') 1017 tensor = torch.zeros(2, 3, names=names, dtype=torch.bool) 1018 result = torch.empty(0, dtype=torch.bool) 1019 1020 self.assertEqual(tensor.bitwise_not().names, names) 1021 self.assertEqual(torch.bitwise_not(tensor, out=result).names, names) 1022 self.assertEqual(tensor.bitwise_not_().names, names) 1023 1024 def test_logical_not(self): 1025 for device in get_all_device_types(): 1026 names = ('N', 'D') 1027 tensor = torch.zeros(2, 3, names=names, dtype=torch.bool) 1028 result = torch.empty(0, dtype=torch.bool) 1029 1030 self.assertEqual(tensor.logical_not().names, names) 1031 self.assertEqual(torch.logical_not(tensor, out=result).names, names) 1032 self.assertEqual(tensor.logical_not_().names, names) 1033 1034 def test_bernoulli(self): 1035 for device in get_all_device_types(): 1036 names = ('N', 'D') 1037 tensor = torch.rand(2, 3, names=names) 1038 result = torch.empty(0) 1039 self.assertEqual(tensor.bernoulli().names, names) 1040 1041 torch.bernoulli(tensor, out=result) 1042 self.assertEqual(result.names, names) 1043 1044 def test_flatten(self): 1045 tensor = torch.randn(2, 3, 5, 7, 11, names=('N', 'C', 'D', 'H', 'W')) 1046 1047 # basic 1048 out = tensor.flatten('D', 'W', 'features') 1049 self.assertEqual(out.names, ['N', 'C', 'features']) 1050 self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1)) 1051 1052 # int overload 1053 out = tensor.flatten(2, 4, 'features') 1054 self.assertEqual(out.names, ['N', 'C', 'features']) 1055 self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1)) 1056 1057 # list overload 1058 out = tensor.flatten(['D', 'H', 'W'], 'features') 1059 self.assertEqual(out.names, ['N', 'C', 'features']) 1060 self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1)) 1061 1062 # Non-contiguous flatten: N and H are not "adjacent" in memory. 1063 sentences = torch.randn(2, 3, 5, 7, names=('N', 'T', 'H', 'D')) 1064 sentences = sentences.transpose('T', 'H') 1065 out = sentences.flatten('N', 'H', 'N_H') 1066 self.assertEqual(out.names, ['N_H', 'T', 'D']) 1067 1068 with self.assertRaisesRegex(RuntimeError, "Name 'L' not found in"): 1069 tensor.flatten(['D', 'L'], 'features') 1070 1071 with self.assertRaisesRegex(RuntimeError, "must be consecutive in"): 1072 tensor.flatten(['D', 'W'], 'features') 1073 1074 with self.assertRaisesRegex(RuntimeError, "must be consecutive in"): 1075 tensor.flatten(['H', 'D', 'W'], 'features') 1076 1077 def test_flatten_nodims(self): 1078 tensor = torch.empty((2, 3)) 1079 with self.assertRaisesRegex(RuntimeError, "cannot be empty"): 1080 tensor.flatten((), 'abcd') 1081 1082 def test_flatten_index_error(self): 1083 tensor = torch.randn(1, 2) 1084 with self.assertRaisesRegex(IndexError, 1085 r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"): 1086 tensor.flatten(0, 2) 1087 with self.assertRaisesRegex(IndexError, 1088 r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"): 1089 tensor.flatten(0, 2, 'N') 1090 with self.assertRaisesRegex(RuntimeError, 1091 r"flatten\(\) has invalid args: start_dim cannot come after end_dim"): 1092 tensor.flatten(1, 0) 1093 with self.assertRaisesRegex(RuntimeError, 1094 r"flatten\(\) has invalid args: start_dim cannot come after end_dim"): 1095 tensor.flatten(1, 0, 'N') 1096 1097 def test_unflatten(self): 1098 # test args: tensor, int, namedshape 1099 self.assertTrue(torch.equal( 1100 torch.ones(4, names=('A',)).unflatten('A', (('A', 2), ('B', 2))), 1101 torch.ones(2, 2, names=('A', 'B')))) 1102 self.assertTrue(torch.equal( 1103 torch.ones(4, names=('A',)).unflatten('A', [('A', 2), ('B', 2)]), 1104 torch.ones(2, 2, names=('A', 'B')))) 1105 self.assertTrue(torch.equal( 1106 torch.ones(4, names=('A',)).unflatten('A', (['A', 2], ['B', 2])), 1107 torch.ones(2, 2, names=('A', 'B')))) 1108 self.assertTrue(torch.equal( 1109 torch.ones(2, 10, names=('A', 'B')).unflatten('B', (['B1', -1],)), 1110 torch.ones(2, 10, names=('A', 'B1')))) 1111 self.assertTrue(torch.equal( 1112 torch.ones(2, 3 * 4 * 5 * 6, names=('A', 'B')) 1113 .unflatten('B', (['B1', 3], ['B2', 4], ['B3', -1], ['B4', 6])), 1114 torch.ones(2, 3, 4, 5, 6, names=('A', 'B1', 'B2', 'B3', 'B4')))) 1115 self.assertTrue(torch.equal( 1116 torch.ones(2, 0, names=('A', 'B')) 1117 .unflatten('B', (['B1', 3], ['B2', -1], ['B3', 4])), 1118 torch.ones(2, 3, 0, 4, names=('A', 'B1', 'B2', 'B3')))) 1119 1120 # test args: namedtensor, str, namedshape 1121 self.assertTrue(torch.equal( 1122 torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))), 1123 torch.ones(2, 2, 2, names=('A', 'B1', 'B2')))) 1124 1125 # test invalid args: namedtensor, str, sizes 1126 with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"): 1127 torch.tensor([1], names=('A',)).unflatten('A', (1, 1)) 1128 1129 # test invalid args: namedtensor, int, sizes 1130 with self.assertRaisesRegex(RuntimeError, r"input is a named tensor but no names were given for unflattened sizes"): 1131 torch.tensor([1], names=("A",)).unflatten(0, (1, 1)) 1132 1133 with self.assertRaisesRegex(RuntimeError, 1134 r"Provided sizes \[3, -1\] don't multiply up to the " 1135 r"size of dim 1 \('B': 4\) in Tensor\['A', 'B'\]"): 1136 torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 3), ('B2', -1))) 1137 1138 with self.assertRaisesRegex(RuntimeError, 1139 r"the unspecified dimension size -1 can be any value and is ambiguous"): 1140 torch.ones(2, 0, names=('A', 'B')).unflatten('B', (('B1', 0), ('B2', -1))) 1141 1142 tensor = torch.randn(7, 2 * 3 * 5, 11, names=('N', 'D', 'K')) 1143 1144 # accepts OrderedDict 1145 out = tensor.unflatten('D', OrderedDict((('C', 2), ('H', 3), ('W', 5)))) 1146 self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K')) 1147 self.assertEqual(out.shape, (7, 2, 3, 5, 11)) 1148 1149 # Unflatten left-most 1150 out = tensor.unflatten('N', (('N', 7), ('H', 1))) 1151 self.assertEqual(out.names, ('N', 'H', 'D', 'K')) 1152 self.assertEqual(out.shape, (7, 1, 2 * 3 * 5, 11)) 1153 1154 # Unflatten right-most 1155 out = tensor.unflatten('K', (('K', 11), ('H', 1))) 1156 self.assertEqual(out.names, ('N', 'D', 'K', 'H')) 1157 self.assertEqual(out.shape, (7, 2 * 3 * 5, 11, 1)) 1158 1159 with self.assertRaisesRegex(RuntimeError, "don't multiply up to"): 1160 tensor.unflatten('D', (('H', 3), ('W', 5))) 1161 1162 with self.assertRaisesRegex(RuntimeError, 'sizes must be non-empty'): 1163 tensor.unflatten('D', None) 1164 1165 with self.assertRaisesRegex(RuntimeError, 'non-empty'): 1166 tensor.unflatten('D', OrderedDict()) 1167 1168 def test_unsupported_op_error_msg(self): 1169 named = torch.randn(3, 3, names=('N', 'C')) 1170 with self.assertRaisesRegex( 1171 RuntimeError, r"pdist.+is not yet supported with named tensors"): 1172 torch.pdist(named) 1173 with self.assertRaisesRegex( 1174 RuntimeError, r"as_strided_.+is not yet supported with named tensors"): 1175 named.as_strided_((3, 3), (3, 1)) 1176 1177 def test_reduction_fns(self): 1178 def check_output(output, expected_names): 1179 if isinstance(output, torch.Tensor): 1180 self.assertEqual(output.names, expected_names) 1181 return 1182 for out in output: 1183 self.assertEqual(out.names, expected_names) 1184 1185 def sum_all_outputs(output): 1186 if isinstance(output, torch.Tensor): 1187 return output.sum() 1188 result = 0 1189 for out in output: 1190 result = out + result 1191 return result.sum() 1192 1193 def test_simple_reduce(op, device): 1194 t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) 1195 check_output(op(t, 1), ['N', 'L']) 1196 check_output(op(t, -1), ['N', 'C']) 1197 check_output(op(t, 'C'), ['N', 'L']) 1198 ops_support_dim_none = [ 1199 'sum', 1200 'mean', 1201 'std', 1202 'var', 1203 'std_mean', 1204 'var_mean', 1205 'nanmean', 1206 'nansum', 1207 ] 1208 if op.__name__ in ops_support_dim_none: 1209 check_output(op(t, None), []) 1210 else: 1211 with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'): 1212 op(t, None) 1213 with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'): 1214 op(t, 'H') 1215 1216 def test_autograd_supports_dimname_overload(op, device): 1217 t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device, requires_grad=True) 1218 sum_all_outputs(op(t, 'C')).backward() 1219 self.assertIsNotNone(t.grad) 1220 1221 def test_complete_reduce(op, device): 1222 t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) 1223 check_output(op(t), []) 1224 1225 def test_multidim_reduce(op, device): 1226 t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) 1227 1228 check_output(op(t, [1, 2]), ['N']) 1229 check_output(op(t, [0, -1]), ['C']) 1230 check_output(op(t, ['C', 'L']), ['N']) 1231 with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'): 1232 op(t, [None, 'C']) 1233 1234 def test_out_variant(op, output_lambda, device): 1235 t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) 1236 if output_lambda: 1237 out = output_lambda(t) 1238 else: 1239 out = torch.empty([0], device=device) 1240 op(t, 'C', out=out) 1241 check_output(out, ['N', 'L']) 1242 1243 def test_keepdim(op, device): 1244 t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) 1245 check_output(op(t, 'C', keepdim=True), ['N', 'C', 'L']) 1246 1247 def values_and_indices(t): 1248 return (torch.empty([0], device=t.device), 1249 torch.empty([0], device=t.device, dtype=torch.long)) 1250 1251 def kthvalue_wrapper(tensor, *args, **kwargs): 1252 # Return the 0-th value 1253 return torch.kthvalue(tensor, 1, *args, **kwargs) 1254 1255 Case = namedtuple('Case', [ 1256 'op', 1257 'supports_complete_reduce', 1258 'supports_multidim_reduce', 1259 'supports_out_variant', 1260 'supports_keepdim', 1261 'output_lambda', 1262 ]) 1263 1264 tests = [ 1265 Case(torch.sum, True, True, True, True, None), 1266 Case(torch.prod, True, False, True, True, None), 1267 Case(torch.mean, True, True, True, True, None), 1268 Case(torch.var, True, True, True, True, None), 1269 Case(torch.std, True, True, True, True, None), 1270 Case(torch.std_mean, True, True, False, True, None), 1271 Case(torch.var_mean, True, True, False, True, None), 1272 Case(torch.min, True, False, True, True, values_and_indices), 1273 Case(torch.max, True, False, True, True, values_and_indices), 1274 Case(torch.unbind, False, False, False, False, None), 1275 Case(torch.logsumexp, False, True, True, True, None), 1276 Case(torch.mode, False, False, True, True, values_and_indices), 1277 Case(kthvalue_wrapper, False, False, True, True, values_and_indices), 1278 Case(torch.median, True, False, True, True, values_and_indices), 1279 Case(torch.nanmedian, True, False, True, True, values_and_indices), 1280 ] 1281 1282 for testcase, device in itertools.product(tests, get_all_device_types()): 1283 op = testcase.op 1284 test_simple_reduce(op, device) 1285 test_autograd_supports_dimname_overload(op, device) 1286 1287 if testcase.supports_keepdim: 1288 test_keepdim(op, device) 1289 if testcase.supports_out_variant: 1290 test_out_variant(op, testcase.output_lambda, device) 1291 if testcase.supports_complete_reduce: 1292 test_complete_reduce(op, device) 1293 if testcase.supports_multidim_reduce: 1294 test_multidim_reduce(op, device) 1295 1296 def test_masked_select(self): 1297 # simple 1298 self._test_name_inference( 1299 torch.masked_select, 1300 (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')), 1301 expected_names=[None]) 1302 1303 # left broadcast 1304 self._test_name_inference( 1305 torch.masked_select, 1306 (create('C:3'), (create('2,3') > 0).rename('N', 'C')), 1307 expected_names=[None]) 1308 1309 # right broadcast 1310 self._test_name_inference( 1311 torch.masked_select, 1312 (create('N:2,C:3'), (create('3') > 0).rename('C')), 1313 expected_names=[None]) 1314 1315 # error 1316 self._test_name_inference( 1317 torch.masked_select, 1318 (create('N:2,C:3'), (create('3') > 0).rename('D')), 1319 maybe_raises_regex='do not match') 1320 1321 # out= 1322 self._test_name_inference( 1323 out_fn(torch.masked_select), 1324 (create('0'), create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')), 1325 expected_names=[None]) 1326 1327 def test_cat(self): 1328 # simple 1329 self._test_name_inference( 1330 torch.cat, 1331 [[create('N:2,C:3'), create('N:2,C:3')]], 1332 expected_names=['N', 'C']) 1333 1334 # error: zero dim 1335 self._test_name_inference( 1336 torch.cat, 1337 [[create(''), create('')]], 1338 maybe_raises_regex='zero-dim') 1339 1340 # error: names don't match 1341 self._test_name_inference( 1342 torch.cat, 1343 [[create('N:2,C:3'), create('C:3,N:2')]], 1344 maybe_raises_regex='do not match') 1345 1346 # error: different number of dims 1347 self._test_name_inference( 1348 torch.cat, 1349 [[create('N:2,C:3'), create('C:3')]], 1350 maybe_raises_regex='must have same number of dimensions') 1351 1352 # out= 1353 self._test_name_inference( 1354 out_fn(torch.cat), 1355 [create('0'), [create('N:2,C:3'), create('N:2,C:3')]], 1356 expected_names=['N', 'C']) 1357 1358 def test_masked_fill(self): 1359 # simple 1360 self._test_name_inference( 1361 Tensor.masked_fill, 1362 (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), 1363 expected_names=['N', 'C']) 1364 1365 # left broadcast 1366 self._test_name_inference( 1367 Tensor.masked_fill, 1368 (create('C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), 1369 maybe_raises_regex="must be less than or equal to") 1370 1371 # right broadcast 1372 self._test_name_inference( 1373 Tensor.masked_fill, 1374 (create('N:2,C:3'), (create('3') > 0).rename('C'), 3.14), 1375 expected_names=['N', 'C']) 1376 1377 # error 1378 self._test_name_inference( 1379 Tensor.masked_fill, 1380 (create('N:2,C:3'), (create('3') > 0).rename('D'), 3.14), 1381 maybe_raises_regex='do not match') 1382 1383 # inplace 1384 self._test_name_inference( 1385 Tensor.masked_fill_, 1386 (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), 1387 expected_names=['N', 'C']) 1388 1389 # inplace, computed names don't match output tensor names 1390 self._test_name_inference( 1391 Tensor.masked_fill_, 1392 (create('N:2,None:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), 1393 maybe_raises_regex="not the same as the computed output names") 1394 1395 1396 def test_using_seen_interned_string_doesnt_bump_refcount(self): 1397 def see_name(): 1398 seen_name = 'N' 1399 pass_name_to_python_arg_parser(seen_name) 1400 1401 see_name() 1402 seen_name = 'N' 1403 old_refcnt = sys.getrefcount(seen_name) 1404 1405 pass_name_to_python_arg_parser(seen_name) 1406 1407 new_refcnt = sys.getrefcount(seen_name) 1408 self.assertEqual(new_refcnt, old_refcnt) 1409 1410 # This test is failing on Python 3.12: https://github.com/pytorch/pytorch/issues/119464 1411 @unittest.skipIf(sys.version_info >= (3, 12), "Failing on python 3.12+") 1412 def test_using_unseen_interned_string_bumps_refcount_permanently(self): 1413 # Please don't use this as a name in a different test. 1414 unseen_name = 'abcdefghi' 1415 old_refcnt = sys.getrefcount(unseen_name) 1416 1417 pass_name_to_python_arg_parser(unseen_name) 1418 1419 new_refcnt = sys.getrefcount(unseen_name) 1420 self.assertEqual(new_refcnt, old_refcnt + 1) 1421 1422 # This test is failing on Python 3.12: https://github.com/pytorch/pytorch/issues/119464 1423 @unittest.skipIf(sys.version_info >= (3, 12), "Failing on python 3.12+") 1424 def test_using_unseen_uninterned_string_refcounts(self): 1425 # Please don't use this as a name in a different test. 1426 # non-compile-time constants are not interned 1427 unseen_name = ''.join(['abc', 'def', 'ghi', 'jkl']) 1428 interned_unseen_name = 'abcdefghijkl' 1429 self.assertFalse(unseen_name is interned_unseen_name) 1430 1431 old_uninterned_refcnt = sys.getrefcount(unseen_name) 1432 old_interned_refcnt = sys.getrefcount(interned_unseen_name) 1433 1434 pass_name_to_python_arg_parser(unseen_name) 1435 1436 new_uninterned_refcnt = sys.getrefcount(unseen_name) 1437 new_interned_refcnt = sys.getrefcount(interned_unseen_name) 1438 1439 # Internally, PyTorch should not hold a reference to the uninterned string 1440 self.assertEqual(new_uninterned_refcnt, old_uninterned_refcnt) 1441 1442 # Instead, we should hold a new reference to the interned version. 1443 self.assertEqual(new_interned_refcnt, old_interned_refcnt + 1) 1444 1445 def _test_select(self, device): 1446 x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device) 1447 y = x.select(1, 1) 1448 self.assertEqual(y.names, ('N', 'H', 'W')) 1449 1450 y = x.select('C', 1) 1451 self.assertEqual(y.names, ('N', 'H', 'W')) 1452 1453 with self.assertRaisesRegex( 1454 RuntimeError, 'Please look up dimensions by name'): 1455 y = x.select(None, 1) 1456 1457 def test_select(self): 1458 self._test_select('cpu') 1459 1460 @unittest.skipIf(not TEST_CUDA, 'no CUDA') 1461 def test_select_cuda(self): 1462 self._test_select('cuda') 1463 1464 def _test_as_strided(self, device): 1465 x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device) 1466 y = x.as_strided([2 * 3 * 4 * 5], [1]) 1467 self.assertEqual(y.names, (None,)) 1468 1469 def test_as_strided(self): 1470 self._test_as_strided('cpu') 1471 1472 @unittest.skipIf(not TEST_CUDA, 'no CUDA') 1473 def test_as_strided_cuda(self): 1474 self._test_as_strided('cuda') 1475 1476 def test_no_jit_tracer_support(self): 1477 def foo(x): 1478 return torch.full(x.shape, 2., names=('N',)) 1479 1480 with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'): 1481 x = torch.randn(3) 1482 torch.jit.trace(foo, example_inputs=x) 1483 1484 def bar(x): 1485 return x.select('N', 1) 1486 1487 with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'): 1488 x = torch.randn(3) 1489 torch.jit.trace(bar, example_inputs=x) 1490 1491 def test_no_jit_script_support(self): 1492 @torch.jit.script 1493 def foo(x): 1494 return x + 1 1495 1496 with self.assertRaisesRegex(RuntimeError, 'NYI'): 1497 foo(torch.randn(2, 3, names=('N', 'C'))) 1498 1499 @torch.jit.ignore 1500 def add_names(x): 1501 x.names = ('N', 'C') 1502 1503 @torch.jit.script 1504 def return_named_tensor(input): 1505 add_names(input) 1506 return input 1507 1508 with self.assertRaisesRegex(RuntimeError, "NYI"): 1509 return_named_tensor(torch.randn(1, 1)) 1510 1511 def test_align_to(self): 1512 # trivial 1513 tensor = create('N:3') 1514 output = tensor.align_to('N') 1515 self.assertEqual(output.names, ['N']) 1516 self.assertEqual(output.shape, [3]) 1517 1518 # unsqueeze behavior 1519 tensor = create('N:3') 1520 output = tensor.align_to('N', 'D') 1521 self.assertEqual(output.names, ['N', 'D']) 1522 self.assertEqual(output.shape, [3, 1]) 1523 1524 # transpose behavior 1525 tensor = create('N:3,C:2') 1526 output = tensor.align_to('C', 'N') 1527 self.assertEqual(output.names, ['C', 'N']) 1528 self.assertEqual(output.shape, [2, 3]) 1529 1530 # unsqueeze / transpose 1531 tensor = create('C:2,N:3,H:5') 1532 output = tensor.align_to('N', 'H', 'W', 'C') 1533 self.assertEqual(output.names, ['N', 'H', 'W', 'C']) 1534 self.assertEqual(output.shape, [3, 5, 1, 2]) 1535 1536 # All input dimensions must be named 1537 with self.assertRaisesRegex(RuntimeError, "All input dims must be named. Found unnamed dim at index 0"): 1538 create('None:2,C:3').align_to('N', 'C') 1539 1540 # not enough names 1541 with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'N'"): 1542 create('N:2,C:3').align_to('C') 1543 1544 # names not found 1545 with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'C'"): 1546 create('N:2,C:3').align_to('D', 'N') 1547 1548 def test_align_to_ellipsis(self): 1549 tensor = create('N:7,H:3,W:5,C:2') 1550 1551 # ... = ['N', 'H', 'W', 'C'] 1552 output = tensor.align_to('...') 1553 self.assertEqual(output.names, ['N', 'H', 'W', 'C']) 1554 self.assertEqual(output.shape, [7, 3, 5, 2]) 1555 1556 # ... = ['H', 'C'] 1557 output = tensor.align_to('...', 'W', 'N') 1558 self.assertEqual(output.names, ['H', 'C', 'W', 'N']) 1559 self.assertEqual(output.shape, [3, 2, 5, 7]) 1560 1561 # ... = ['N', 'W'] 1562 output = tensor.align_to('H', 'C', '...') 1563 self.assertEqual(output.names, ['H', 'C', 'N', 'W']) 1564 self.assertEqual(output.shape, [3, 2, 7, 5]) 1565 1566 # ... = ['H', 'C'] 1567 output = tensor.align_to('W', '...', 'N') 1568 self.assertEqual(output.names, ['W', 'H', 'C', 'N']) 1569 self.assertEqual(output.shape, [5, 3, 2, 7]) 1570 1571 # ... = [] 1572 output = tensor.align_to('N', '...', 'C', 'D', 'H', 'W') 1573 self.assertEqual(output.names, ['N', 'C', 'D', 'H', 'W']) 1574 self.assertEqual(output.shape, [7, 2, 1, 3, 5]) 1575 1576 # Input tensor partially named 1577 partially_named = create('None:2,None:3,None:5,C:7') 1578 output = partially_named.align_to('C', '...') 1579 self.assertEqual(output.names, ['C', None, None, None]) 1580 self.assertEqual(output.shape, [7, 2, 3, 5]) 1581 1582 with self.assertRaisesRegex(RuntimeError, "order of dimensions cannot contain a None"): 1583 partially_named.align_to('C', None, '...') 1584 1585 # Input order partially named 1586 with self.assertRaisesRegex(RuntimeError, "cannot contain a None name"): 1587 tensor.align_to('...', 'N', None) 1588 1589 # Input order duplicate names 1590 with self.assertRaisesRegex(RuntimeError, "duplicate names"): 1591 tensor.align_to('...', 'N', 'N') 1592 1593 def test_align_as(self): 1594 # align_as calls align_to internally. align_to has pretty substantial tests, 1595 # so just test some basic things here. 1596 tensor = create('C:2,N:3,H:5') 1597 other = create('N:1,H:1,W:1,C:1') 1598 output = tensor.align_as(other) 1599 self.assertEqual(output.names, ['N', 'H', 'W', 'C']) 1600 self.assertEqual(output.shape, [3, 5, 1, 2]) 1601 1602 @unittest.skip("Not implemented yet") 1603 def test_align_tensors_two_inputs(self): 1604 def _test(tensor_namedshape, align_names, expected_sizes, expected_error): 1605 tensor_names, tensor_sizes = tensor_namedshape 1606 tensor = torch.empty(*tensor_sizes, names=tensor_names) 1607 other = torch.empty([1] * len(align_names), names=align_names) 1608 if expected_error is not None: 1609 with self.assertRaisesRegex(RuntimeError, expected_error): 1610 torch.align_tensors(tensor, other) 1611 return 1612 1613 output, _ = torch.align_tensors(tensor, other) 1614 self.assertEqual(output.shape, expected_sizes) 1615 self.assertEqual(output.names, align_names) 1616 1617 Case = namedtuple('Case', [ 1618 'tensor_namedshape', 1619 'align_names', 1620 'expected_sizes', 1621 'expected_error', 1622 ]) 1623 1624 tests = [ 1625 # basic tests 1626 Case(tensor_namedshape=(['C'], [2]), 1627 align_names=['C'], 1628 expected_sizes=[2], 1629 expected_error=None), 1630 Case(tensor_namedshape=(['C'], [2]), 1631 align_names=['D'], 1632 expected_sizes=None, 1633 expected_error='not a subsequence'), 1634 1635 # single-dim alignment test 1636 Case(tensor_namedshape=(['C'], [2]), 1637 align_names=['N', 'C'], 1638 expected_sizes=[1, 2], 1639 expected_error=None), 1640 Case(tensor_namedshape=[['N'], [2]], 1641 align_names=['N', 'C'], 1642 expected_sizes=[2, 1], 1643 expected_error=None), 1644 1645 # multiple dim alignment test 1646 Case(tensor_namedshape=[['N', 'C'], [2, 3]], 1647 align_names=['N', 'H', 'C', 'W'], 1648 expected_sizes=[2, 1, 3, 1], 1649 expected_error=None), 1650 Case(tensor_namedshape=[['N', 'C'], [2, 3]], 1651 align_names=['C', 'H', 'N', 'W'], 1652 expected_sizes=None, 1653 expected_error='not a subsequence'), 1654 1655 # scalar tensor tests 1656 Case(tensor_namedshape=[None, [[]]], 1657 align_names=['N', 'C'], 1658 expected_sizes=[1, 1], 1659 expected_error=None), 1660 Case(tensor_namedshape=[[], [[]]], 1661 align_names=[None, None], 1662 expected_sizes=[1, 1], 1663 expected_error=None), 1664 1665 # unnamed tensor tests 1666 Case(tensor_namedshape=[None, [2, 3]], 1667 align_names=[None, None], 1668 expected_sizes=[2, 3], 1669 expected_error=None), 1670 Case(tensor_namedshape=[None, [2, 3]], 1671 align_names=[None, None, None], 1672 expected_sizes=[1, 2, 3], 1673 expected_error=None), 1674 Case(tensor_namedshape=[None, [2]], 1675 align_names=['N'], 1676 expected_sizes=None, 1677 expected_error='not a subsequence'), 1678 1679 # unnamed dim alignment tests 1680 Case(tensor_namedshape=[[None], [2]], 1681 align_names=['N', None], 1682 expected_sizes=[1, 2], 1683 expected_error=None), 1684 Case(tensor_namedshape=[[None], [2]], 1685 align_names=['N', None, None, None], 1686 expected_sizes=[1, 1, 1, 2], 1687 expected_error=None), 1688 Case(tensor_namedshape=[['N'], [2]], 1689 align_names=['N', None, None, None], 1690 expected_sizes=[2, 1, 1, 1], 1691 expected_error=None), 1692 Case(tensor_namedshape=[[None, 'N', None], [2, 3, 5]], 1693 align_names=[None, None, 'N', None], 1694 expected_sizes=[1, 2, 3, 5], 1695 expected_error=None), 1696 Case(tensor_namedshape=[[None], [2]], 1697 align_names=[None, 'N'], 1698 expected_sizes=None, 1699 expected_error='absolute position from the right'), 1700 Case(tensor_namedshape=[None, [2]], 1701 align_names=[None, 'N'], 1702 expected_sizes=None, 1703 expected_error='absolute position from the right'), 1704 Case(tensor_namedshape=[[None, 'N'], [2, 3]], 1705 align_names=[None, 'C', 'N'], 1706 expected_sizes=None, 1707 expected_error='absolute position from the right'), 1708 ] 1709 1710 for test in tests: 1711 _test(*test) 1712 1713 @unittest.skip("Not implemented yet") 1714 def test_align_tensors(self): 1715 def reference_fn(*tensors): 1716 longest_names = tensors[0].names 1717 for tensor in tensors: 1718 if len(tensor.names) > len(longest_names): 1719 longest_names = tensor.names 1720 return [tensor.align_to(*longest_names) for tensor in tensors] 1721 1722 x = torch.empty(1, 1, names=('N', 'H')) 1723 y = torch.empty(2, 3, 5, names=('N', 'C', 'H')) 1724 z = torch.empty(2, names=('N',)) 1725 output = torch.align_tensors(x, y, z) 1726 expected_tensors = reference_fn(x, y, z) 1727 for tensor, expected in zip(output, expected_tensors): 1728 self.assertTensorDataAndNamesEqual(tensor, expected) 1729 1730 def test_mm(self): 1731 for device in get_all_device_types(): 1732 self._test_name_inference( 1733 torch.mm, device=device, 1734 args=(create('N:3,C:2'), create('W:2,H:5')), 1735 expected_names=('N', 'H')) 1736 1737 # left arg is unnamed 1738 self._test_name_inference( 1739 torch.mm, device=device, 1740 args=(create('3,2'), create('W:2,H:5')), 1741 expected_names=(None, 'H')) 1742 1743 # right arg is unnamed 1744 self._test_name_inference( 1745 torch.mm, device=device, 1746 args=(create('N:3,C:2'), create('2,5')), 1747 expected_names=('N', None)) 1748 1749 # out= 1750 self._test_name_inference( 1751 out_fn(torch.mm), device=device, 1752 args=(create('0'), create('N:3,C:2'), create('W:2,H:5')), 1753 expected_names=('N', 'H')) 1754 1755 self._test_name_inference( 1756 torch.mm, device=device, 1757 args=(create('N:3,C:2'), create('W:2,N:5')), 1758 maybe_raises_regex='with duplicate names') 1759 1760 def test_expand(self): 1761 for device in get_all_device_types(): 1762 self._test_name_inference( 1763 Tensor.expand, device=device, 1764 args=(create('D:1'), [3]), expected_names=('D',)) 1765 1766 self._test_name_inference( 1767 Tensor.expand, device=device, 1768 args=(create('H:3,W:2'), [10, 3, 3, 2]), 1769 expected_names=(None, None, 'H', 'W')) 1770 1771 self._test_name_inference( 1772 Tensor.expand, device=device, 1773 args=(create('3, 2'), [10, 3, 3, 2]), 1774 expected_names=(None, None, None, None)) 1775 1776 def test_addmm(self): 1777 for device in get_all_device_types(): 1778 # full names 1779 self._test_name_inference( 1780 torch.addmm, device=device, 1781 args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')), 1782 expected_names=('N', 'H')) 1783 1784 # no name on bias 1785 self._test_name_inference( 1786 torch.addmm, device=device, 1787 args=(create('3,5'), create('N:3,C:2'), create('W:2,H:5')), 1788 expected_names=('N', 'H')) 1789 1790 # partially named bias 1791 self._test_name_inference( 1792 torch.addmm, device=device, 1793 args=(create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')), 1794 expected_names=('N', 'H')) 1795 1796 # out= 1797 self._test_name_inference( 1798 out_fn(torch.addmm), device=device, 1799 args=(create('0'), create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')), 1800 expected_names=('N', 'H')) 1801 1802 # inplace 1803 self._test_name_inference( 1804 torch.Tensor.addmm_, device=device, 1805 args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')), 1806 expected_names=('N', 'H')) 1807 1808 self._test_name_inference( 1809 torch.addmm, device=device, 1810 args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,N:5')), 1811 maybe_raises_regex='with duplicate names') 1812 1813 def test_bmm(self): 1814 for device in get_all_device_types(): 1815 # full names 1816 self._test_name_inference( 1817 torch.bmm, device=device, 1818 args=(create('N:7,A:3,B:2'), create('N:7,A:2,B:5')), 1819 expected_names=('N', 'A', 'B')) 1820 1821 # no name on left tensor 1822 self._test_name_inference( 1823 torch.bmm, device=device, 1824 args=(create('7,3,2'), create('N:7,A:2,B:5')), 1825 expected_names=('N', None, 'B')) 1826 1827 # no name on right tensor 1828 self._test_name_inference( 1829 torch.bmm, device=device, 1830 args=(create('N:7,A:3,B:2'), create('7,2,5')), 1831 expected_names=('N', 'A', None)) 1832 1833 # out= 1834 self._test_name_inference( 1835 out_fn(torch.bmm), device=device, 1836 args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')), 1837 expected_names=('N', 'A', 'B')) 1838 1839 # duplicate names after mm 1840 self._test_name_inference( 1841 torch.bmm, device=device, 1842 args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')), 1843 maybe_raises_regex='with duplicate names') 1844 1845 # matching error (batch dimensions must be alignable) 1846 self._test_name_inference( 1847 torch.bmm, device=device, 1848 args=(create('N:3,A:3,B:3'), create('M:3,A:3,B:3')), 1849 maybe_raises_regex='do not match') 1850 1851 # misalignment (batch dimension is getting contracted) 1852 self._test_name_inference( 1853 torch.bmm, device=device, 1854 args=(create('N:3,A:3,B:3'), create('None:3,N:3,B:3')), 1855 maybe_raises_regex='misaligned') 1856 1857 def test_matmul(self): 1858 for device in get_all_device_types(): 1859 # input tensors are less than 1D 1860 self._test_name_inference( 1861 torch.matmul, device=device, 1862 args=(create(''), create('A:2')), 1863 maybe_raises_regex='at least 1D') 1864 self._test_name_inference( 1865 torch.matmul, device=device, 1866 args=(create('A:2'), create('')), 1867 maybe_raises_regex='at least 1D') 1868 1869 # 1D @ 1D 1870 self._test_name_inference( 1871 torch.matmul, device=device, 1872 args=(create('A:2'), create('B:2')), 1873 expected_names=[]) 1874 1875 # ND @ 1D 1876 self._test_name_inference( 1877 torch.matmul, device=device, 1878 args=(create('A:3,C:2'), create('B:2')), 1879 expected_names=['A']) 1880 self._test_name_inference( 1881 torch.matmul, device=device, 1882 args=(create('A:5,C:3,D:2'), create('B:2')), 1883 expected_names=['A', 'C']) 1884 1885 # 1D @ ND 1886 self._test_name_inference( 1887 torch.matmul, device=device, 1888 args=(create('C:2'), create('A:2,B:3')), 1889 expected_names=['B']) 1890 self._test_name_inference( 1891 torch.matmul, device=device, 1892 args=(create('C:2'), create('A:3,B:2,D:5')), 1893 expected_names=['A', 'D']) 1894 1895 # 2D @ 2D 1896 self._test_name_inference( 1897 torch.matmul, device=device, 1898 args=(create('A:3,B:2'), create('A:2,B:3')), 1899 expected_names=['A', 'B']) 1900 self._test_name_inference( 1901 torch.matmul, device=device, 1902 args=(create('A:3,B:2'), create('B:2,A:5')), 1903 maybe_raises_regex='with duplicate names') 1904 1905 # ND @ ND where N >= 2 1906 self._test_name_inference( 1907 torch.matmul, device=device, 1908 args=(create('C:5,A:3,B:2'), create('A:2,B:3')), 1909 expected_names=['C', 'A', 'B']) 1910 self._test_name_inference( 1911 torch.matmul, device=device, 1912 args=(create('C:5,A:3,B:2'), create('None:1,A:2,B:3')), 1913 expected_names=['C', 'A', 'B']) 1914 self._test_name_inference( 1915 torch.matmul, device=device, 1916 args=(create('C:5,A:3,B:2'), create('None:2,None:1,A:2,B:3')), 1917 expected_names=[None, 'C', 'A', 'B']) 1918 1919 # out= 1920 self._test_name_inference( 1921 out_fn(torch.matmul), device=device, 1922 args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')), 1923 expected_names=('N', 'A', 'B')) 1924 1925 # duplicate names after mm 1926 self._test_name_inference( 1927 torch.bmm, device=device, 1928 args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')), 1929 maybe_raises_regex='with duplicate names') 1930 1931 # misalignment (batch dimension is getting contracted) 1932 self._test_name_inference( 1933 torch.matmul, device=device, 1934 args=(create('N:3,A:3,B:3'), create('A:3,N:3,B:3')), 1935 maybe_raises_regex='do not match') 1936 1937 def test_mv(self): 1938 for device in get_all_device_types(): 1939 self._test_name_inference( 1940 torch.mv, device=device, 1941 args=(create('N:3,C:2'), create('W:2')), 1942 expected_names=('N',)) 1943 1944 # left arg is unnamed 1945 self._test_name_inference( 1946 torch.mv, device=device, 1947 args=(create('3,2'), create('W:2')), 1948 expected_names=(None,)) 1949 1950 # right arg is unnamed 1951 self._test_name_inference( 1952 torch.mv, device=device, 1953 args=(create('N:3,C:2'), create('2')), 1954 expected_names=('N',)) 1955 1956 # out= 1957 self._test_name_inference( 1958 out_fn(torch.mv), device=device, 1959 args=(create('0'), create('N:3,C:2'), create('W:2')), 1960 expected_names=('N',)) 1961 1962 def test_addmv(self): 1963 for device in get_all_device_types(): 1964 # full names 1965 self._test_name_inference( 1966 torch.addmv, device=device, 1967 args=(create('N:3'), create('N:3,C:2'), create('H:2')), 1968 expected_names=['N']) 1969 1970 # no name on bias 1971 self._test_name_inference( 1972 torch.addmv, device=device, 1973 args=(create('3'), create('N:3,C:2'), create('H:2')), 1974 expected_names=('N',)) 1975 1976 # out= 1977 self._test_name_inference( 1978 out_fn(torch.addmv), device=device, 1979 args=(create('0'), create('N:3'), create('N:3,C:2'), create('H:2')), 1980 expected_names=('N',)) 1981 1982 # inplace 1983 self._test_name_inference( 1984 torch.Tensor.addmv_, device=device, 1985 args=(create('N:3'), create('N:3,C:2'), create('H:2')), 1986 expected_names=('N',)) 1987 1988 def test_autograd_ignores_names(self): 1989 # sigmoid forward is supported by named tensors, but sigmoid_backward 1990 # is not (see native_functions.yaml). Test that autograd ignores names 1991 # and that the sigmoid_backward succeeds. 1992 x = torch.randn(3, 3, names=('N', 'C'), requires_grad=True) 1993 x.sigmoid().sum().backward() 1994 1995 def test_tensor_grad_is_unnamed(self): 1996 x = torch.randn(3, 3, names=(None, None), requires_grad=True) 1997 y = torch.randn(3, 3, names=('N', 'C'), requires_grad=True) 1998 (x * y).sum().backward() 1999 2000 # Check that names weren't propagated 2001 self.assertEqual(y.grad.names, [None, None]) 2002 self.assertEqual(x.grad.names, [None, None]) 2003 2004 def test_autograd_warns_named_grad(self): 2005 base = torch.randn(3, 3, names=('N', 'C')) 2006 named_grad = base.clone() 2007 base.requires_grad_() 2008 2009 with warnings.catch_warnings(record=True) as warns: 2010 # Cause all warnings to always be triggered. 2011 warnings.simplefilter("always") 2012 base.clone().backward(named_grad) 2013 self.assertEqual(len(warns), 1) 2014 self.assertTrue( 2015 str(warns[0].message).startswith('Autograd was passed a named grad tensor')) 2016 2017 def test_nyi_dimname_overload_msg(self): 2018 x = torch.randn(3, 3) 2019 with self.assertRaisesRegex(RuntimeError, "squeeze: You passed a dimname"): 2020 x.squeeze_("N") 2021 2022 def test_dot(self): 2023 for device in get_all_device_types(): 2024 # torch.dot ignores the names of both tensors 2025 self._test_name_inference( 2026 torch.dot, device=device, 2027 args=(create('C:2'), create('W:2')), 2028 expected_names=[]) 2029 2030 def test_comparison_ops(self): 2031 for device in get_all_device_types(): 2032 a = torch.randn(3, 3, names=('N', 'C'), device=device) 2033 b = torch.randn(3, 3, names=('N', 'C'), device=device) 2034 scalar = torch.randn([], device=device) 2035 2036 self.assertEqual((a == b).names, ['N', 'C']) 2037 self.assertEqual((a != b).names, ['N', 'C']) 2038 self.assertEqual((a > b).names, ['N', 'C']) 2039 self.assertEqual((a < b).names, ['N', 'C']) 2040 self.assertEqual((a >= b).names, ['N', 'C']) 2041 self.assertEqual((a <= b).names, ['N', 'C']) 2042 2043 self.assertEqual((a == 1).names, ['N', 'C']) 2044 self.assertEqual((a != 1).names, ['N', 'C']) 2045 self.assertEqual((a > 1).names, ['N', 'C']) 2046 self.assertEqual((a < 1).names, ['N', 'C']) 2047 self.assertEqual((a >= 1).names, ['N', 'C']) 2048 self.assertEqual((a <= 1).names, ['N', 'C']) 2049 2050 self.assertEqual((a == scalar).names, ['N', 'C']) 2051 self.assertEqual((a != scalar).names, ['N', 'C']) 2052 self.assertEqual((a > scalar).names, ['N', 'C']) 2053 self.assertEqual((a < scalar).names, ['N', 'C']) 2054 self.assertEqual((a >= scalar).names, ['N', 'C']) 2055 self.assertEqual((a <= scalar).names, ['N', 'C']) 2056 2057 res = torch.empty(3, 3, dtype=torch.bool, device=device) 2058 torch.eq(a, b, out=res) 2059 self.assertEqual(res.names, ['N', 'C']) 2060 torch.ne(a, b, out=res) 2061 self.assertEqual(res.names, ['N', 'C']) 2062 torch.lt(a, b, out=res) 2063 self.assertEqual(res.names, ['N', 'C']) 2064 torch.gt(a, b, out=res) 2065 self.assertEqual(res.names, ['N', 'C']) 2066 torch.le(a, b, out=res) 2067 self.assertEqual(res.names, ['N', 'C']) 2068 torch.ge(a, b, out=res) 2069 self.assertEqual(res.names, ['N', 'C']) 2070 2071 res = torch.isnan(a) 2072 self.assertEqual(res.names, ['N', 'C']) 2073 2074 res = torch.isinf(a) 2075 self.assertEqual(res.names, ['N', 'C']) 2076 2077 def test_support_device_named_grad(self): 2078 named_tensor = torch.randn(3, 3, device='meta') 2079 with self.assertRaisesRegex(RuntimeError, 'NYI: named tensors only support CPU, CUDA'): 2080 named_tensor.rename_('N', 'C') 2081 named_tensor.names = ['N', 'C'] 2082 named_tensor = torch.randn(3, 3, device='meta', names=['N', 'C']) 2083 2084 2085if __name__ == '__main__': 2086 run_tests() 2087