# Owner(s): ["module: named tensor"]

import unittest
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import get_all_device_types
from collections import namedtuple, OrderedDict
import itertools
import functools
import torch
from torch import Tensor
import torch.nn.functional as F
from multiprocessing.reduction import ForkingPickler
import pickle
import io
import sys
import warnings


def pass_name_to_python_arg_parser(name):
    x = torch.empty(2, names=(name,))


def flatten(lst):
    return [item for sublist in lst for item in sublist]


Function = namedtuple('Function', ['name', 'lambd'])
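# Function pairs a readable op name with a lambda so the parameterized checks
# below can report which op failed, e.g. Function('abs', lambda t: t.abs()).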


def parse_compressed_namedshape(string):
    # This is a metalanguage for describing a shape of a tensor compactly.
    # 'N:3,C:2' -> size = [3, 2], names: ['N', 'C']
    # 'None:3,None:2' -> size = [3, 2], names: [None, None]
    # '3,2' -> size = [3, 2], names=None passed to ctor.
    def parse_name(maybe_name):
        maybe_name = maybe_name.strip()
        if maybe_name == 'None':
            return None
        return maybe_name

    string = string.strip()

    # '' -> size: [], names:None
    if len(string) == 0:
        return None, []

    # '3, 2' -> size = [3, 2], None names.
    if ':' not in string:
        return None, [int(size) for size in string.split(',')]

    dims = string.split(',')
    tuples = [dim.split(':') for dim in dims]
    return zip(*[(parse_name(name), int(size)) for name, size in tuples])


def create(namedshape, factory=torch.randn):
    # namedshape: str
    names, shape = parse_compressed_namedshape(namedshape)
    return factory(shape, names=names)

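# Illustrative examples of the metalanguage above (comments only):
#   create('N:3,C:2')    -> shape (3, 2), names ('N', 'C')
#   create('None:3,C:2') -> shape (3, 2), names (None, 'C')
#   create('3,2')        -> shape (3, 2), unnamed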

def out_fn(operator):
    @functools.wraps(operator)
    def fn(*inputs):
        return operator(*inputs[1:], out=inputs[0])
    return fn

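# For illustration: out_fn(torch.abs)(buf, t) calls torch.abs(t, out=buf), so
# tests can pass the out= buffer as the first positional argument.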

class TestNamedTensor(TestCase):
    def test_aaa_must_run_first_check_experimental_warning(self):
        # TODO(rzou): It would be nice for this to be a "real" python warning.
        # Right now this error message only prints once and doesn't respect
        # warnings.simplefilter behavior (where python users can control whether
        # or not to display warnings once, all the time, or never).
        with warnings.catch_warnings(record=True) as warns:
            x = torch.randn(3, 3, names=('N', 'C'))
            self.assertEqual(len(warns), 1)
            self.assertTrue(str(warns[0].message).startswith(
                'Named tensors and all their associated APIs are an experimental feature'))

    def test_trivial(self):
        pass

    def _test_name_inference(self, op, args=(), expected_names=(), device='cpu',
                             maybe_raises_regex=None):
        casted_args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg
                       for arg in args]
        if maybe_raises_regex is not None:
            with self.assertRaisesRegex(RuntimeError, maybe_raises_regex):
                result = op(*casted_args)
            return
        result = op(*casted_args)
        self.assertEqual(result.names, expected_names,
                         msg=f'Name inference for {op.__name__} on device {device} failed')

    # TODO(rzou): Some form of this check should be added to self.assertEqual.
    # Right now I don't know what it should look like.
    def assertTensorDataAndNamesEqual(self, x, y):
        self.assertEqual(x.names, y.names)
        unnamed_x = x.rename(None)
        unnamed_y = y.rename(None)
        self.assertEqual(unnamed_x, unnamed_y)

    def _test_factory(self, factory, device):
        x = factory([], device=device)
        self.assertEqual(x.names, ())

        x = factory(1, 2, 3, device=device)
        self.assertEqual(x.names, (None, None, None))

        x = factory(1, 2, 3, names=None, device=device)
        self.assertEqual(x.names, (None, None, None))

        x = factory(1, 2, 3, names=('N', 'T', 'D'), device=device)
        self.assertEqual(x.names, ('N', 'T', 'D'))

        x = factory(1, 2, 3, names=('N', None, 'D'), device=device)
        self.assertEqual(x.names, ('N', None, 'D'))

        x = factory(1, 2, 3, names=('_1', 'batch9', 'BATCH_5'), device=device)
        self.assertEqual(x.names, ('_1', 'batch9', 'BATCH_5'))

        with self.assertRaisesRegex(RuntimeError,
                                    'a valid identifier contains only'):
            x = factory(2, names=('1',), device=device)

        with self.assertRaisesRegex(RuntimeError,
                                    'a valid identifier contains only'):
            x = factory(2, names=('?',), device=device)

        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
            x = factory(2, 1, names=('N',), device=device)

        with self.assertRaisesRegex(TypeError, 'invalid combination of arguments'):
            x = factory(2, 1, names='N', device=device)

        with self.assertRaisesRegex(RuntimeError, 'construct a tensor with duplicate names'):
            x = factory(2, 1, 1, names=('N', 'C', 'N'), device=device)

        names64 = ['A' * i for i in range(1, 65)]
        x = factory([1] * 64, names=names64, device=device)
        self.assertEqual(x.names, names64)

        with self.assertRaisesRegex(
                RuntimeError,
                'only support up to 64 dims'):
            names65 = ['A' * i for i in range(1, 66)]
            x = factory([1] * 65, names=names65, device=device)

    @skipIfTorchDynamo("not a bug: Dynamo causes the refcounts to be different")
    def test_none_names_refcount(self, N=10):
        def scope():
            unnamed = torch.empty(2, 3)
            unnamed.names  # materialize [None, None]

        prev_none_refcnt = sys.getrefcount(None)
        # Run scope() N times to reduce flakiness.
        for _ in range(N):
            scope()
        after_none_refcnt = sys.getrefcount(None)
        self.assertTrue(after_none_refcnt - prev_none_refcnt < N / 2,
                        msg='Using tensor.names should not change '
                            'the refcount of Py_None')

    def test_has_names(self):
        unnamed = torch.empty(2, 3)
        none_named = torch.empty(2, 3, names=(None, None))
        partially_named = torch.empty(2, 3, names=('N', None))
        fully_named = torch.empty(2, 3, names=('N', 'C'))

        self.assertFalse(unnamed.has_names())
        self.assertFalse(none_named.has_names())
        self.assertTrue(partially_named.has_names())
        self.assertTrue(fully_named.has_names())

    def test_py3_ellipsis(self):
        tensor = torch.randn(2, 3, 5, 7)
        output = tensor.refine_names('N', ..., 'C')
        self.assertEqual(output.names, ['N', None, None, 'C'])

    def test_refine_names(self):
        # Unnamed tensor -> Named tensor
        self._test_name_inference(Tensor.refine_names,
                                  [create('None:1,None:2,None:3'), 'N', 'C', 'H'],
                                  ['N', 'C', 'H'])

        # Named tensor -> Named tensor
        self._test_name_inference(Tensor.refine_names,
                                  [create('N:1,C:2,H:3'), 'N', 'C', 'H'],
                                  ['N', 'C', 'H'])

        # Partially named tensor -> named tensor
        self._test_name_inference(Tensor.refine_names,
                                  [create('None:1,C:2,None:3'), None, 'C', 'H'],
                                  [None, 'C', 'H'])

        # Too few names
        self._test_name_inference(Tensor.refine_names,
                                  [create('None:2,None:3'), 'N', 'C', 'H'],
                                  maybe_raises_regex="different number of dims")

        # Cannot change Tensor[D] to Tensor[N]
        self._test_name_inference(Tensor.refine_names,
                                  [create('D:3'), 'N'],
                                  maybe_raises_regex="is different from")

        # Cannot change Tensor[D] to Tensor[None]
        self._test_name_inference(Tensor.refine_names,
                                  [create('D:3'), None],
                                  maybe_raises_regex="'D' is more specific than None")

        # Globbing: '...' expands to match any number of remaining dims
        self._test_name_inference(Tensor.refine_names,
                                  [create('None:1,None:1,None:2,None:3'), '...', 'C', 'H'],
                                  [None, None, 'C', 'H'])

    def test_detach(self):
        names = ['N']
        self._test_name_inference(
            Tensor.detach_,
            [torch.randn(3, requires_grad=True, names=names)],
            names)
        self._test_name_inference(
            Tensor.detach,
            [torch.randn(3, requires_grad=True, names=names)],
            names)

    def test_index_fill(self):
        for device in get_all_device_types():
            expected_names = ('N', 'C')
            x = torch.randn(3, 5, device=device, names=expected_names)

            output = x.index_fill_('C', torch.tensor([0, 1], device=device), 5)
            self.assertEqual(output.names, expected_names)

            output = x.index_fill_('C', torch.tensor([0, 1], device=device), torch.tensor(4.))
            self.assertEqual(output.names, expected_names)

            output = x.index_fill('C', torch.tensor([0, 1], device=device), 5)
            self.assertEqual(output.names, expected_names)

            output = x.index_fill('C', torch.tensor([0, 1], device=device), torch.tensor(4.))
            self.assertEqual(output.names, expected_names)

    def test_equal(self):
        for device in get_all_device_types():
            tensor = torch.randn(2, 3, device=device)
            other = tensor.clone()

            self.assertTrue(torch.equal(tensor.rename('N', 'C'), other.rename('N', 'C')))
            self.assertFalse(torch.equal(tensor.rename('M', 'C'), other.rename('N', 'C')))
            self.assertFalse(torch.equal(tensor.rename(None, 'C'), other.rename('N', 'C')))

    def test_squeeze(self):
        x = create('N:3,C:1,H:1,W:1')
        output = x.squeeze('C')
        self.assertEqual(output.names, ['N', 'H', 'W'])

        output = x.squeeze()
        self.assertEqual(output.names, ['N'])

    def test_repr(self):
        named_tensor = torch.zeros(2, 3).rename_('N', 'C')
        expected = "tensor([[0., 0., 0.],\n        [0., 0., 0.]], names=('N', 'C'))"
        self.assertEqual(repr(named_tensor), expected)

        unnamed_tensor = torch.zeros(2, 3)
        expected = "tensor([[0., 0., 0.],\n        [0., 0., 0.]])"
        self.assertEqual(repr(unnamed_tensor), expected)

        none_named_tensor = torch.zeros(2, 3).rename_(None, None)
        self.assertEqual(repr(none_named_tensor), expected)

    def test_diagonal(self):
        named_tensor = torch.zeros(2, 3, 5, 7, names=list('ABCD'))
        self.assertEqual(named_tensor.diagonal().names, ['C', 'D', None])
        self.assertEqual(named_tensor.diagonal(1, 3).names, ['A', 'C', None])

        self.assertEqual(named_tensor.diagonal(outdim='E', dim1='B', dim2='D').names,
                         ['A', 'C', 'E'])

    def test_max_pooling(self):
        def check_tuple_return(op, inputs, expected_names):
            values, indices = op(*inputs)
            self.assertEqual(values.names, expected_names)
            self.assertEqual(indices.names, expected_names)

        for device in get_all_device_types():

            named_tensor_1d = torch.zeros(2, 3, 5, device=device, names=list('ABC'))
            named_tensor_2d = torch.zeros(2, 3, 5, 7, device=device, names=list('ABCD'))
            named_tensor_3d = torch.zeros(2, 3, 5, 7, 9, device=device, names=list('ABCDE'))

            self.assertEqual(F.max_pool1d(named_tensor_1d, 2).names, named_tensor_1d.names)
            self.assertEqual(F.max_pool2d(named_tensor_2d, [2, 2]).names, named_tensor_2d.names)
            self.assertEqual(F.max_pool3d(named_tensor_3d, [2, 2, 2]).names, named_tensor_3d.names)

            check_tuple_return(F.max_pool1d_with_indices, [named_tensor_1d, 2], named_tensor_1d.names)
            check_tuple_return(F.max_pool2d_with_indices, [named_tensor_2d, [2, 2]], named_tensor_2d.names)
            check_tuple_return(F.max_pool3d_with_indices, [named_tensor_3d, [2, 2, 2]], named_tensor_3d.names)

    def test_max_pooling_without_names_does_not_warn(self):
        for device in get_all_device_types():
            tensor_2d = torch.zeros(2, 3, 5, 7, device=device, requires_grad=True)
            with warnings.catch_warnings(record=True) as warns:
                warnings.simplefilter("always")
                result = F.max_pool2d(tensor_2d, [2, 2])
                result.sum().backward()
                self.assertEqual(len(warns), 0)

    def test_no_save_support(self):
        named_tensor = torch.zeros(2, 3, names=('N', 'C'))
        buf = io.BytesIO()
        with self.assertRaisesRegex(RuntimeError, "NYI"):
            torch.save(named_tensor, buf)

    def test_no_pickle_support(self):
        named_tensor = torch.zeros(2, 3, names=('N', 'C'))
        with self.assertRaisesRegex(RuntimeError, "NYI"):
            serialized = pickle.dumps(named_tensor)

    def test_no_multiprocessing_support(self):
        named_tensor = torch.zeros(2, 3, names=('N', 'C'))
        buf = io.BytesIO()
        with self.assertRaisesRegex(RuntimeError, "NYI"):
            ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(named_tensor)

    def test_big_tensor_repr_has_names(self):
        def check_repr(named_tensor):
            unnamed_tensor = named_tensor.rename(None)
            names_tag = f'names={named_tensor.names}'
            self.assertIn(names_tag, repr(named_tensor))

        check_repr(torch.randn(128, 3, 64, 64, names=('N', 'C', 'H', 'W')))

    def test_noncontig_contiguous(self):
        # This type of contiguous is special-cased and therefore needs its own test
        for device in get_all_device_types():
            x = torch.randn(2, 3, device=device).t().rename_('N', 'C')
            self.assertEqual(x.contiguous().names, ('N', 'C'))

    def test_copy_transpose(self):
        # This type of copy is special-cased and therefore needs its own test
        def _test(self_names, other_names, expected_names):
            x = torch.empty(2, 5, names=self_names)
            y = torch.empty(5, 2).t().rename_(*other_names)
            x.copy_(y)
            self.assertEqual(x.names, expected_names)

        _test(('N', 'C'), ('N', 'C'), ('N', 'C'))
        _test(None, ('N', 'C'), ('N', 'C'))

    def test_rename_(self):
        tensor = torch.empty(1, 1, names=('N', 'C'))
        self.assertEqual(tensor.rename_(None).names, (None, None))
        self.assertEqual(tensor.rename_('H', 'W').names, ('H', 'W'))
        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
            tensor.rename_('N', 'C', 'W')
        with self.assertRaisesRegex(RuntimeError, 'duplicate names'):
            tensor.rename_('N', 'N')

    def test_rename(self):
        tensor = torch.empty(1, 1, names=('N', 'C'))

        self.assertEqual(tensor.rename(None).names, (None, None))
        self.assertEqual(tensor.rename('H', 'W').names, ('H', 'W'))

        # Check that we didn't modify tensor.names
        self.assertEqual(tensor.names, ('N', 'C'))

        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
            tensor.rename('N', 'C', 'W')
        with self.assertRaisesRegex(RuntimeError, 'duplicate names'):
            tensor.rename('N', 'N')

        with self.assertRaisesRegex(RuntimeError, 'either positional args or keyword args'):
            tensor.rename(None, N='batch')

        # rename returns a view on the tensor
        self.assertEqual(tensor.rename('H', 'W').data_ptr(), tensor.data_ptr())
        self.assertEqual(tensor.rename(None).data_ptr(), tensor.data_ptr())

    def test_rename_globber(self):
        scalar = torch.randn([])
        unnamed_tensor = torch.empty(1, 1, 1, 1)
        named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W'))

        self.assertEqual(scalar.rename(None).names, [])
        self.assertEqual(scalar.rename('...').names, [])

        # Check that it works with unnamed tensors
        self.assertEqual(unnamed_tensor.rename('...').names, unnamed_tensor.names)
        self.assertEqual(unnamed_tensor.rename('...', 'H', 'W').names,
                         [None, None, 'H', 'W'])
        self.assertEqual(unnamed_tensor.rename('N', '...', 'W').names,
                         ['N', None, None, 'W'])
        self.assertEqual(unnamed_tensor.rename('N', 'C', '...').names,
                         ['N', 'C', None, None])

        # Check that it works with named tensors
        self.assertEqual(named_tensor.rename('...').names, named_tensor.names)
        self.assertEqual(named_tensor.rename('...', 'width').names,
                         ['N', 'C', 'H', 'width'])
        self.assertEqual(named_tensor.rename('batch', 'channels', '...', 'width').names,
                         ['batch', 'channels', 'H', 'width'])
        self.assertEqual(named_tensor.rename('batch', '...').names,
                         ['batch', 'C', 'H', 'W'])

        # Test empty glob
        self.assertEqual(unnamed_tensor.rename('...', None, None, None, None).names,
                         [None, None, None, None])
        self.assertEqual(named_tensor.rename('N', 'C', 'H', '...', 'W').names,
                         ['N', 'C', 'H', 'W'])

        # Multiple globs throw
        with self.assertRaisesRegex(RuntimeError, 'More than one '):
            named_tensor.rename('...', 'channels', '...')

    def test_rename_rename_map(self):
        scalar = torch.randn([])
        unnamed_tensor = torch.empty(1, 1, 1, 1)
        named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W'))

        with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"):
            scalar.rename(N='batch')
        with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"):
            unnamed_tensor.rename(N='batch')
        with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"):
            named_tensor.rename(B='batch')
        with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"):
            named_tensor.rename(H='height', B='batch')

        self.assertEqual(named_tensor.rename(N='batch').data_ptr(),
                         named_tensor.data_ptr())
        self.assertEqual(named_tensor.rename(N='batch').names,
                         ['batch', 'C', 'H', 'W'])
        self.assertEqual(named_tensor.rename(N='batch', H='height').names,
                         ['batch', 'C', 'height', 'W'])

    def test_set_names_property(self):
        tensor = torch.empty(1, 1, names=('N', 'C'))

        tensor.names = None
        self.assertEqual(tensor.names, (None, None))

        tensor.names = ('N', 'W')
        self.assertEqual(tensor.names, ('N', 'W'))

        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
            tensor.names = ['N', 'C', 'W']
        with self.assertRaisesRegex(RuntimeError, 'duplicate names'):
            tensor.names = ['N', 'N']

    def test_factory_edge_cases(self):
        for device in get_all_device_types():
            self._test_factory(torch.empty, device)

    def test_factory_coverage(self):
        def _test(factory, device):
            names = ('N', 'T', 'D')

            torch.manual_seed(0)
            result = factory(1, 2, 3, names=names, device=device)

            torch.manual_seed(0)
            expected = factory(1, 2, 3, device=device).rename_(*names)

            self.assertTensorDataAndNamesEqual(result, expected)

        supported = [
            torch.ones,
            torch.rand,
            torch.randn,
            torch.zeros,
        ]

        for op, device in itertools.product(supported, get_all_device_types()):
            _test(op, device)

        # Test torch.full
        for device in get_all_device_types():
            names = ('N', 'T', 'D')
            result = torch.full([1, 2, 3], 2., names=names, device=device)
            expected = torch.full([1, 2, 3], 2., device=device).rename_(*names)
            self.assertTensorDataAndNamesEqual(result, expected)

    def test_tensor_from_lists(self):
        names = ('N', 'C')
        tensor = torch.tensor([[1]], names=names)
        self.assertEqual(tensor.names, names)

        names = ('N',)
        tensor = torch.tensor([1], names=names)
        self.assertEqual(tensor.names, names)

        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
            names = ('N', 'C')
            tensor = torch.tensor([1], names=names)

    @unittest.skipIf(not TEST_NUMPY, "no numpy")
    def test_tensor_from_numpy(self):
        import numpy as np
        arr = np.array([[1]])
        names = ('N', 'C')
        tensor = torch.tensor(arr, names=names)
        self.assertEqual(tensor.names, names)

    def test_tensor_from_tensor(self):
        x = torch.randn(1, 1)
        names = ('N', 'C')
        tensor = torch.tensor(x, names=names)
        self.assertEqual(tensor.names, names)

    def test_tensor_from_named_tensor(self):
        x = torch.randn(1, 1, names=('N', 'D'))
        tensor = torch.tensor(x)
        self.assertEqual(tensor.names, ('N', 'D'))

        # there's no way to distinguish between names=None and not passing in names.
        # If the user passes in names=None they are asking for trouble.
        x = torch.randn(1, 1, names=('N', 'D'))
        tensor = torch.tensor(x, names=None)
        self.assertEqual(tensor.names, ('N', 'D'))

        x = torch.randn(1, 1, names=('N', 'D'))
        with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
            tensor = torch.tensor(x, names=('N', 'C'))

    def test_size(self):
        t = torch.empty(2, 3, 5, names=('N', None, 'C'))
        self.assertEqual(t.size('N'), 2)
        self.assertEqual(t.size('C'), 5)
        with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
            t.size('channels')
        with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
            torch.empty(2, 3, 4).size('N')

    def test_stride(self):
        t = torch.empty(2, 3, 5, names=('N', None, 'C'))
        self.assertEqual(t.stride('N'), 3 * 5)
        self.assertEqual(t.stride('C'), 1)
        with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
            t.stride('channels')
        with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
            torch.empty(2, 3, 4).stride('N')

    def test_transpose_variants(self):
        t = torch.randn(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
        self.assertEqual(t.transpose('N', 'C').names, ['C', 'N', 'H', 'W'])
        self.assertEqual(t.transpose(1, 3).names, ['N', 'W', 'H', 'C'])

        t = torch.randn(2, 3, names=('N', 'C'))
        self.assertEqual(t.t().names, ['C', 'N'])

    def test_resize(self):
        for device in get_all_device_types():
            named = torch.randn(2, names=('N',), device=device)
            named.resize_([2])
            self.assertEqual(named.names, ['N'])

            with self.assertRaisesRegex(RuntimeError, "Cannot resize named tensor"):
                named.resize_([3])

            other_named = torch.randn(2, names=('N',), device=device)
            named.resize_as_(other_named)
            self.assertEqual(other_named.names, ['N'])

            unnamed = torch.randn(2, device=device)
            with self.assertRaisesRegex(
                    RuntimeError, r'names .* are not the same as the computed output names'):
                named.resize_as_(unnamed)

            unnamed = torch.randn(1, device=device)
            unnamed.resize_as_(named)
            self.assertEqual(unnamed.names, ['N'])

    def test_cdist(self):
        for device in get_all_device_types():
            tensor = torch.randn(3, 1, 2, 7, names=('M', 'N', 'first_group', 'features'),
                                 device=device)
            other = torch.randn(5, 11, 7, names=('N', 'second_group', 'features'),
                                device=device)
            result = torch.cdist(tensor, other)
            self.assertEqual(result.names, ['M', 'N', 'first_group', 'second_group'])

    def test_info_smoke(self):
        # Smoke test for info functions / methods / attributes on named tensors.
        tensor = torch.empty(1, 1, names=('N', 'D'))

        tensor.device
        tensor.dtype
        tensor.get_device()
        tensor.is_complex()
        tensor.is_floating_point()
        tensor.is_nonzero()
        torch.is_same_size(tensor, tensor)
        torch.is_signed(tensor)
        tensor.layout
        tensor.numel()
        tensor.dim()
        tensor.element_size()
        tensor.is_contiguous()
        tensor.is_cuda
        tensor.is_leaf
        tensor.is_pinned()
        tensor.is_shared()
        tensor.is_sparse
        tensor.ndimension()
        tensor.nelement()
        tensor.shape
        tensor.size()
        tensor.size(1)
        tensor.storage()
        tensor.storage_offset()
        tensor.storage_type()
        tensor.stride()
        tensor.stride(1)
        tensor.data
        tensor.data_ptr()
        tensor.ndim
        tensor.item()
        tensor.type()
        tensor.is_shared()
        tensor.is_signed()

    def test_autograd_smoke(self):
        x = torch.randn(3, 3, names=('N', 'D'), requires_grad=True)

        y = x.clone()
        y.retain_grad()
        y.register_hook(lambda x: x)

        y.sum().backward()

        # autograd related attributes
        tensor = torch.empty(1, 1, names=('N', 'D'), requires_grad=True)
        tensor = tensor.relu()
        tensor.output_nr
        tensor.grad_fn
        tensor.requires_grad

    def test_split_fns_propagates_names(self):
        fns = [
            lambda x: x.split(1, 0),
            lambda x: x.split([1, 1], 1),
            lambda x: x.chunk(2, 0),
        ]

        for device in get_all_device_types():
            orig_tensor = torch.empty(2, 2, names=('N', 'D'), device=device)
            for fn in fns:
                splits = fn(orig_tensor)
                for split in splits:
                    self.assertEqual(split.names, orig_tensor.names)

    def test_any_all(self):
        for device in get_all_device_types():
            x = torch.zeros(3, dtype=torch.bool, device=device, names=('C',))
            self.assertEqual(x.any().names, [])
            self.assertEqual(x.all().names, [])

    def test_addcmul_addcdiv(self):
        for device in get_all_device_types():
            names = ['N']
            a = torch.rand(3, device=device, names=names)
            b = torch.rand(3, device=device, names=names)
            # avoid division by 0
            c = torch.rand(3, device=device, names=names).clamp_min_(0.1)
            out = torch.randn(3, device=device, names=names)

            self.assertEqual(torch.addcmul(a, b, c).names, names)
            self.assertEqual(torch.addcmul(a, b, c, out=out).names, names)
            self.assertEqual(a.addcmul_(b, c).names, names)

            self.assertEqual(torch.addcdiv(a, b, c).names, names)
            self.assertEqual(torch.addcdiv(a, b, c, out=out).names, names)
            self.assertEqual(a.addcdiv_(b, c).names, names)

    def test_binary_ops(self):
        def test_basic(op):
            a = torch.empty(2, 3, names=('N', 'C'))
            b = torch.empty(3, 2, names=('C', 'N'))
            c = torch.empty(3, names=('C',))
            d = torch.empty(5, names=('W',))

            self.assertEqual(op(a, a).names, ('N', 'C'))
            self.assertEqual(op(a, c).names, ('N', 'C'))
            # TODO: dynamo will throw a slightly different
            # error message because it's adding fake tensors
            # `must match the size of` portion is the dynamo error
            with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"):
                op(a, d)
            with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"):
                op(a, b)

        def test_wildcard(op):
            a = torch.empty(2, 3, names=('N', 'C'))
            c = torch.empty(2, 3, names=(None, 'C'))
            self.assertEqual(op(a, c).names, ('N', 'C'))

            b = torch.empty(2, 3)
            self.assertEqual(op(a, b).names, ('N', 'C'))

            d = torch.empty(2, 3, names=('C', None))
            with self.assertRaisesRegex(RuntimeError, "Misaligned"):
                op(d, c)

        def test_mixed_unnamed_named(op, is_inplace):
            named2 = torch.randn(1, 1, names=('N', 'C'))
            unnamed1 = torch.randn(1)
            unnamed2 = torch.randn(1, 1)
            unnamed3 = torch.randn(1, 1, 1)

            def compute_expected_names(tensor, other):
                assert tensor.has_names() ^ other.has_names()
                named = tensor if tensor.has_names() else other
                unnamed = other if tensor.has_names() else tensor
                unnamed_dim = unnamed.dim()
                if unnamed_dim > named.dim():
                    return [None] * (unnamed_dim - named.dim()) + list(named.names)
                else:
                    return named.names

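            # For illustration: pairing named2 (1, 1) ['N', 'C'] with the
            # 3-dim unnamed tensor gives [None, 'N', 'C']; with the 1-dim
            # unnamed tensor, the expected names stay ['N', 'C'].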
            inputs = itertools.chain(
                itertools.product([named2], [unnamed1, unnamed2, unnamed3]),
                itertools.product([unnamed1, unnamed2, unnamed3], [named2]),
            )
            if is_inplace:
                # In-place ops have the constraint that they must not change shape.
                inputs = [(a, b) for (a, b) in inputs if a.dim() >= b.dim()]

            for tensor, other in inputs:
                expected_names = compute_expected_names(tensor, other)
                self.assertEqual(op(tensor, other).names, expected_names)

        def method(name, *args, **kwargs):
            return [Function(name, lambda a, b: getattr(a, name)(b, *args, **kwargs))]

        def function(name, *args, **kwargs):
            return [Function(name, lambda a, b: getattr(torch, name)(a, b, *args, **kwargs))]

        def out_function(name, *args, **kwargs):
            out_fn = getattr(torch, name)

            def fn(a, b):
                result = torch.empty([0], dtype=a.dtype, device=a.device)
                out_fn(a, b, *args, out=result, **kwargs)
                return result

            return [Function(name, fn)]

        def fn_method_and_inplace(name, *args, **kwargs):
            return (
                method(name, *args, **kwargs) +
                method(name + '_', *args, **kwargs) +
                out_function(name, *args, **kwargs)
            )

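        # For illustration: fn_method_and_inplace('add') expands to the 'add'
        # method, the in-place 'add_' method, and the out= variant of torch.add.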
        tests = [
            fn_method_and_inplace('add'),
            fn_method_and_inplace('div'),
            fn_method_and_inplace('mul'),
            fn_method_and_inplace('sub'),
            fn_method_and_inplace('pow'),
            fn_method_and_inplace('atan2'),
            method('copy_'),
            function('floor_divide'),
            function('true_divide'),
        ]
        tests = flatten(tests)

        for name, op in tests:
            test_basic(op)
            test_wildcard(op)
            test_mixed_unnamed_named(op, is_inplace=name.endswith('_'))

    def test_logical_ops(self):
        # Implemented via TensorIterator, so just check that each version
        # (out-of-place, inplace, out=) propagates names.
        def zeros(*args, **kwargs):
            return torch.zeros(*args, dtype=torch.bool, **kwargs)

        for op in ('logical_xor', 'logical_and', 'logical_or'):
            self._test_name_inference(
                getattr(torch, op),
                (create('N:2,C:3', zeros), create('N:2,C:3', zeros)),
                expected_names=['N', 'C'])

            self._test_name_inference(
                getattr(Tensor, op + '_'),
                (create('N:2,C:3', zeros), create('N:2,C:3', zeros)),
                expected_names=['N', 'C'])

            self._test_name_inference(
                lambda out, x, y: getattr(torch, op)(x, y, out=out),
                (create('0', zeros), create('N:2,C:3', zeros), create('N:2,C:3', zeros)),
                expected_names=['N', 'C'])

    def test_pow_special(self):
        # There are a few pow cases that don't go through TensorIterator.
        # Test them here.
        for device in get_all_device_types():
            named = torch.randn(2, 3, names=('N', 'C'), device=device)
            unnamed = torch.randn([0], device=device)

            result = torch.pow(named, 0, out=unnamed.clone())
            self.assertEqual(result.names, named.names)

            result = torch.pow(named, 1, out=unnamed.clone())
            self.assertEqual(result.names, named.names)

            result = torch.pow(1, named, out=unnamed.clone())
            self.assertEqual(result.names, named.names)

    def test_out_fn_semantics(self):
        out_fn = torch.abs
        unnamed_tensor = torch.randn(3, 2)
        none_named_tensor = torch.randn(3, 2, names=(None, None))
        named_tensor = torch.randn(3, 2, names=('N', 'C'))
        partially_named_tensor = torch.randn(3, 2, names=('N', None))

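        # Rule exercised below: an out= tensor with any real names must match
        # the computed output names exactly; an out= tensor with no names (or
        # all-None names) simply inherits the computed names.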
        with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
            out_fn(partially_named_tensor, out=named_tensor)
        with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
            out_fn(named_tensor, out=partially_named_tensor)
        with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
            out_fn(none_named_tensor, out=named_tensor)
        with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
            out_fn(unnamed_tensor, out=named_tensor)

        output = torch.randn(3, 2)
        out_fn(unnamed_tensor, out=output)
        self.assertFalse(output.has_names())

        output = torch.randn(3, 2, names=(None, None))
        out_fn(named_tensor, out=output)
        self.assertEqual(output.names, named_tensor.names)

        output = torch.randn(3, 2)
        out_fn(named_tensor, out=output)
        self.assertEqual(output.names, named_tensor.names)

        output = torch.randn(3, 2, names=(None, None))
        out_fn(unnamed_tensor, out=output)
        self.assertFalse(output.has_names())

    def test_unary_propagate_names_fns(self):
        def _test(testcase, names=('N', 'D'), device='cpu'):
            sizes = [2] * len(names)
            tensor = torch.empty(sizes, names=names, device=device)
            try:
                out = testcase.lambd(tensor)
            except RuntimeError as err:
                # Get a better error message by catching and re-raising with the test name.
                raise RuntimeError(f'{testcase.name}: {err}') from err
            self.assertEqual(out.names, tensor.names,
                             msg=testcase.name)

        def fn(name, *args, **kwargs):
            return [Function(name, lambda t: getattr(torch, name)(t, *args, **kwargs))]

        def method(name, *args, **kwargs):
            return [Function(name, lambda t: getattr(t, name)(*args, **kwargs))]

        def out_function(name, *args, **kwargs):
            out_fn = getattr(torch, name)

            def fn(tensor):
                result = torch.empty([0], dtype=tensor.dtype, device=tensor.device)
                out_fn(tensor, *args, out=result, **kwargs)
                return result

            return [Function(name + '_out', fn)]

        def fn_method_and_inplace(name, *args, **kwargs):
            return (
                method(name, *args, **kwargs) +
                method(name + '_', *args, **kwargs) +
                out_function(name, *args, **kwargs)
            )

        # All of these operate on 2x2 tensors.
        tests = [
            # unary pointwise
            fn_method_and_inplace('abs'),
            fn_method_and_inplace('acos'),
            fn_method_and_inplace('asin'),
            fn_method_and_inplace('atan'),
            fn_method_and_inplace('ceil'),
            fn_method_and_inplace('clamp', -1, 1),
            fn_method_and_inplace('clamp_min', -2),
            fn_method_and_inplace('clamp_max', 2),
            method('cauchy_'),
            method('clone'),
            method('contiguous'),
            fn_method_and_inplace('cos'),
            fn_method_and_inplace('cosh'),
            fn_method_and_inplace('digamma'),
            fn_method_and_inplace('erf'),
            fn_method_and_inplace('erfc'),
            fn_method_and_inplace('erfinv'),
            fn_method_and_inplace('exp'),
            fn_method_and_inplace('expm1'),
            method('exponential_'),
            fn_method_and_inplace('floor'),
            fn_method_and_inplace('frac'),
            method('geometric_', p=0.5),
            fn_method_and_inplace('lgamma'),
            fn_method_and_inplace('log'),
            fn_method_and_inplace('log10'),
            fn_method_and_inplace('log1p'),
            fn_method_and_inplace('log2'),
            method('log_normal_'),
            fn_method_and_inplace('neg'),
            method('normal_'),
            [Function('polygamma', lambda t: torch.polygamma(1, t))],
            method('polygamma_', 1),
            fn_method_and_inplace('reciprocal'),
            method('random_', 0, 1),
            method('random_', 1),
            method('random_'),
            method('relu_'),
            method('requires_grad_'),
            method('relu'),
            fn_method_and_inplace('round'),
            fn_method_and_inplace('rsqrt'),
            fn_method_and_inplace('sigmoid'),
            fn_method_and_inplace('sign'),
            fn_method_and_inplace('sin'),
            fn_method_and_inplace('sinh'),
            fn_method_and_inplace('sqrt'),
            fn_method_and_inplace('tan'),
            fn_method_and_inplace('tanh'),
            fn('threshold', 0, 1),
            fn('threshold_', 0, 1),
            out_function('threshold', 0, 1),
            fn_method_and_inplace('trunc'),
            method('uniform_'),
            method('zero_'),
            method('fill_', 1),
            method('fill_', torch.tensor(3.14)),

            # conversions
            method('to', dtype=torch.long),
            method('to', device='cpu'),
            method('to', torch.empty([])),
            method('bool'),
            method('byte'),
            method('char'),
            method('cpu'),
            method('double'),
            method('float'),
            method('long'),
            method('half'),
            method('int'),
            method('short'),
            method('type', dtype=torch.long),

            # cumsum and cumprod
            fn('cumsum', 0),
            fn('cumsum', 'D'),
            out_function('cumsum', 'D'),
            fn('cumprod', 0),
            fn('cumprod', 'D'),
            out_function('cumprod', 'D'),

            # views
            method('narrow', 0, 0, 1),

            # creation functions
            fn('empty_like'),
            fn('zeros_like'),
            fn('ones_like'),
            fn('full_like', 3.14),
            fn('rand_like'),
            fn('randn_like'),

            # bernoulli variants
            method('bernoulli_', 0.5),
            method('bernoulli_', torch.tensor(0.5)),

            method('softmax', dim=1),
            method('softmax', dim='D'),
            method('log_softmax', dim=1),
            method('log_softmax', dim='D'),

            [Function('F.dropout(inplace)', lambda t: F.dropout(t, p=0.5, inplace=True))],
            [Function('F.dropout(outplace)', lambda t: F.dropout(t, p=0.5, inplace=False))],
        ]
        tests = flatten(tests)

        for testcase, device in itertools.product(tests, get_all_device_types()):
            _test(testcase, device=device)

    def test_cummax_cummin(self):
        def test_ops(op):
            for device in get_all_device_types():
                names = ('N', 'D')
                tensor = torch.rand(2, 3, names=names, device=device)
                result = op(tensor, 0)
                self.assertEqual(result[0].names, names)
                self.assertEqual(result[1].names, names)
        test_ops(torch.cummax)
        test_ops(torch.cummin)

    def test_logcumsumexp(self):
        for device in get_all_device_types():
            names = ('N', 'D')
            tensor = torch.rand(2, 3, names=names, device=device)
            result = torch.logcumsumexp(tensor, 'D')
            self.assertEqual(result.names, names)

    def test_bitwise_not(self):
        for device in get_all_device_types():
            names = ('N', 'D')
            tensor = torch.zeros(2, 3, names=names, dtype=torch.bool, device=device)
            result = torch.empty(0, dtype=torch.bool, device=device)

            self.assertEqual(tensor.bitwise_not().names, names)
            self.assertEqual(torch.bitwise_not(tensor, out=result).names, names)
            self.assertEqual(tensor.bitwise_not_().names, names)

    def test_logical_not(self):
        for device in get_all_device_types():
            names = ('N', 'D')
            tensor = torch.zeros(2, 3, names=names, dtype=torch.bool, device=device)
            result = torch.empty(0, dtype=torch.bool, device=device)

            self.assertEqual(tensor.logical_not().names, names)
            self.assertEqual(torch.logical_not(tensor, out=result).names, names)
            self.assertEqual(tensor.logical_not_().names, names)

    def test_bernoulli(self):
        for device in get_all_device_types():
            names = ('N', 'D')
            tensor = torch.rand(2, 3, names=names, device=device)
            result = torch.empty(0, device=device)
            self.assertEqual(tensor.bernoulli().names, names)

            torch.bernoulli(tensor, out=result)
            self.assertEqual(result.names, names)

    def test_flatten(self):
        tensor = torch.randn(2, 3, 5, 7, 11, names=('N', 'C', 'D', 'H', 'W'))

        # basic
        out = tensor.flatten('D', 'W', 'features')
        self.assertEqual(out.names, ['N', 'C', 'features'])
        self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1))

        # int overload
        out = tensor.flatten(2, 4, 'features')
        self.assertEqual(out.names, ['N', 'C', 'features'])
        self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1))

        # list overload
        out = tensor.flatten(['D', 'H', 'W'], 'features')
        self.assertEqual(out.names, ['N', 'C', 'features'])
        self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1))

        # Non-contiguous flatten: N and H are not "adjacent" in memory.
        sentences = torch.randn(2, 3, 5, 7, names=('N', 'T', 'H', 'D'))
        sentences = sentences.transpose('T', 'H')
        out = sentences.flatten('N', 'H', 'N_H')
        self.assertEqual(out.names, ['N_H', 'T', 'D'])

        with self.assertRaisesRegex(RuntimeError, "Name 'L' not found in"):
            tensor.flatten(['D', 'L'], 'features')

        with self.assertRaisesRegex(RuntimeError, "must be consecutive in"):
            tensor.flatten(['D', 'W'], 'features')

        with self.assertRaisesRegex(RuntimeError, "must be consecutive in"):
            tensor.flatten(['H', 'D', 'W'], 'features')

    def test_flatten_nodims(self):
        tensor = torch.empty((2, 3))
        with self.assertRaisesRegex(RuntimeError, "cannot be empty"):
            tensor.flatten((), 'abcd')

    def test_flatten_index_error(self):
        tensor = torch.randn(1, 2)
        with self.assertRaisesRegex(IndexError,
                                    r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"):
            tensor.flatten(0, 2)
        with self.assertRaisesRegex(IndexError,
                                    r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"):
            tensor.flatten(0, 2, 'N')
        with self.assertRaisesRegex(RuntimeError,
                                    r"flatten\(\) has invalid args: start_dim cannot come after end_dim"):
            tensor.flatten(1, 0)
        with self.assertRaisesRegex(RuntimeError,
                                    r"flatten\(\) has invalid args: start_dim cannot come after end_dim"):
            tensor.flatten(1, 0, 'N')

    def test_unflatten(self):
        # test args: tensor, int, namedshape
        self.assertTrue(torch.equal(
            torch.ones(4, names=('A',)).unflatten('A', (('A', 2), ('B', 2))),
            torch.ones(2, 2, names=('A', 'B'))))
        self.assertTrue(torch.equal(
            torch.ones(4, names=('A',)).unflatten('A', [('A', 2), ('B', 2)]),
            torch.ones(2, 2, names=('A', 'B'))))
        self.assertTrue(torch.equal(
            torch.ones(4, names=('A',)).unflatten('A', (['A', 2], ['B', 2])),
            torch.ones(2, 2, names=('A', 'B'))))
        self.assertTrue(torch.equal(
            torch.ones(2, 10, names=('A', 'B')).unflatten('B', (['B1', -1],)),
            torch.ones(2, 10, names=('A', 'B1'))))
        self.assertTrue(torch.equal(
            torch.ones(2, 3 * 4 * 5 * 6, names=('A', 'B'))
                 .unflatten('B', (['B1', 3], ['B2', 4], ['B3', -1], ['B4', 6])),
            torch.ones(2, 3, 4, 5, 6, names=('A', 'B1', 'B2', 'B3', 'B4'))))
        self.assertTrue(torch.equal(
            torch.ones(2, 0, names=('A', 'B'))
                 .unflatten('B', (['B1', 3], ['B2', -1], ['B3', 4])),
            torch.ones(2, 3, 0, 4, names=('A', 'B1', 'B2', 'B3'))))

        # test args: namedtensor, str, namedshape
        self.assertTrue(torch.equal(
            torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))),
            torch.ones(2, 2, 2, names=('A', 'B1', 'B2'))))

        # test invalid args: namedtensor, str, sizes
        with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
            torch.tensor([1], names=('A',)).unflatten('A', (1, 1))

        # test invalid args: namedtensor, int, sizes
        with self.assertRaisesRegex(RuntimeError, r"input is a named tensor but no names were given for unflattened sizes"):
            torch.tensor([1], names=("A",)).unflatten(0, (1, 1))

        with self.assertRaisesRegex(RuntimeError,
                                    r"Provided sizes \[3, -1\] don't multiply up to the "
                                    r"size of dim 1 \('B': 4\) in Tensor\['A', 'B'\]"):
            torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 3), ('B2', -1)))

        with self.assertRaisesRegex(RuntimeError,
                                    r"the unspecified dimension size -1 can be any value and is ambiguous"):
            torch.ones(2, 0, names=('A', 'B')).unflatten('B', (('B1', 0), ('B2', -1)))

        tensor = torch.randn(7, 2 * 3 * 5, 11, names=('N', 'D', 'K'))

        # accepts OrderedDict
        out = tensor.unflatten('D', OrderedDict((('C', 2), ('H', 3), ('W', 5))))
        self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K'))
        self.assertEqual(out.shape, (7, 2, 3, 5, 11))

        # Unflatten left-most
        out = tensor.unflatten('N', (('N', 7), ('H', 1)))
        self.assertEqual(out.names, ('N', 'H', 'D', 'K'))
        self.assertEqual(out.shape, (7, 1, 2 * 3 * 5, 11))

        # Unflatten right-most
        out = tensor.unflatten('K', (('K', 11), ('H', 1)))
        self.assertEqual(out.names, ('N', 'D', 'K', 'H'))
        self.assertEqual(out.shape, (7, 2 * 3 * 5, 11, 1))

        with self.assertRaisesRegex(RuntimeError, "don't multiply up to"):
            tensor.unflatten('D', (('H', 3), ('W', 5)))

        with self.assertRaisesRegex(RuntimeError, 'sizes must be non-empty'):
            tensor.unflatten('D', None)

        with self.assertRaisesRegex(RuntimeError, 'non-empty'):
            tensor.unflatten('D', OrderedDict())

    def test_unsupported_op_error_msg(self):
        named = torch.randn(3, 3, names=('N', 'C'))
        with self.assertRaisesRegex(
                RuntimeError, r"pdist.+is not yet supported with named tensors"):
            torch.pdist(named)
        with self.assertRaisesRegex(
                RuntimeError, r"as_strided_.+is not yet supported with named tensors"):
            named.as_strided_((3, 3), (3, 1))

    def test_reduction_fns(self):
        def check_output(output, expected_names):
            if isinstance(output, torch.Tensor):
                self.assertEqual(output.names, expected_names)
                return
            for out in output:
                self.assertEqual(out.names, expected_names)

        def sum_all_outputs(output):
            if isinstance(output, torch.Tensor):
                return output.sum()
            result = 0
            for out in output:
                result = out + result
            return result.sum()

        def test_simple_reduce(op, device):
            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
            check_output(op(t, 1), ['N', 'L'])
            check_output(op(t, -1), ['N', 'C'])
            check_output(op(t, 'C'), ['N', 'L'])
            ops_support_dim_none = [
                'sum',
                'mean',
                'std',
                'var',
                'std_mean',
                'var_mean',
                'nanmean',
                'nansum',
            ]
            if op.__name__ in ops_support_dim_none:
                check_output(op(t, None), [])
            else:
                with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
                    op(t, None)
            with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'):
                op(t, 'H')

        def test_autograd_supports_dimname_overload(op, device):
            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device, requires_grad=True)
            sum_all_outputs(op(t, 'C')).backward()
            self.assertIsNotNone(t.grad)

        def test_complete_reduce(op, device):
            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
            check_output(op(t), [])

        def test_multidim_reduce(op, device):
            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)

            check_output(op(t, [1, 2]), ['N'])
            check_output(op(t, [0, -1]), ['C'])
            check_output(op(t, ['C', 'L']), ['N'])
            with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
                op(t, [None, 'C'])

        def test_out_variant(op, output_lambda, device):
            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
            if output_lambda:
                out = output_lambda(t)
            else:
                out = torch.empty([0], device=device)
            op(t, 'C', out=out)
            check_output(out, ['N', 'L'])

        def test_keepdim(op, device):
            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
            check_output(op(t, 'C', keepdim=True), ['N', 'C', 'L'])

        def values_and_indices(t):
            return (torch.empty([0], device=t.device),
                    torch.empty([0], device=t.device, dtype=torch.long))

        def kthvalue_wrapper(tensor, *args, **kwargs):
            # k=1: return the smallest entry along the reduced dimension.
1253            return torch.kthvalue(tensor, 1, *args, **kwargs)

        Case = namedtuple('Case', [
            'op',
            'supports_complete_reduce',
            'supports_multidim_reduce',
            'supports_out_variant',
            'supports_keepdim',
            'output_lambda',
        ])

        tests = [
            Case(torch.sum, True, True, True, True, None),
            Case(torch.prod, True, False, True, True, None),
            Case(torch.mean, True, True, True, True, None),
            Case(torch.var, True, True, True, True, None),
            Case(torch.std, True, True, True, True, None),
            Case(torch.std_mean, True, True, False, True, None),
            Case(torch.var_mean, True, True, False, True, None),
            Case(torch.min, True, False, True, True, values_and_indices),
            Case(torch.max, True, False, True, True, values_and_indices),
            Case(torch.unbind, False, False, False, False, None),
            Case(torch.logsumexp, False, True, True, True, None),
            Case(torch.mode, False, False, True, True, values_and_indices),
            Case(kthvalue_wrapper, False, False, True, True, values_and_indices),
            Case(torch.median, True, False, True, True, values_and_indices),
            Case(torch.nanmedian, True, False, True, True, values_and_indices),
        ]

        for testcase, device in itertools.product(tests, get_all_device_types()):
            op = testcase.op
            test_simple_reduce(op, device)
            test_autograd_supports_dimname_overload(op, device)

            if testcase.supports_keepdim:
                test_keepdim(op, device)
            if testcase.supports_out_variant:
                test_out_variant(op, testcase.output_lambda, device)
            if testcase.supports_complete_reduce:
                test_complete_reduce(op, device)
            if testcase.supports_multidim_reduce:
                test_multidim_reduce(op, device)

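    # A one-line distillation of the reduction matrix above (a sketch):
    # reducing over a named dim drops that name; keepdim=True keeps it.
    def _reduction_naming_sketch(self):
        t = torch.randn(2, 3, names=('N', 'C'))
        self.assertEqual(t.sum('C').names, ('N',))
        self.assertEqual(t.sum('C', keepdim=True).names, ('N', 'C'))
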
    def test_masked_select(self):
        # simple
        self._test_name_inference(
            torch.masked_select,
            (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')),
            expected_names=[None])

        # left broadcast
        self._test_name_inference(
            torch.masked_select,
            (create('C:3'), (create('2,3') > 0).rename('N', 'C')),
            expected_names=[None])

        # right broadcast
        self._test_name_inference(
            torch.masked_select,
            (create('N:2,C:3'), (create('3') > 0).rename('C')),
            expected_names=[None])

        # error
        self._test_name_inference(
            torch.masked_select,
            (create('N:2,C:3'), (create('3') > 0).rename('D')),
            maybe_raises_regex='do not match')

        # out=
        self._test_name_inference(
            out_fn(torch.masked_select),
            (create('0'), create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')),
            expected_names=[None])

    def test_cat(self):
        # simple
        self._test_name_inference(
            torch.cat,
            [[create('N:2,C:3'), create('N:2,C:3')]],
            expected_names=['N', 'C'])

        # error: zero dim
        self._test_name_inference(
            torch.cat,
            [[create(''), create('')]],
            maybe_raises_regex='zero-dim')

        # error: names don't match
        self._test_name_inference(
            torch.cat,
            [[create('N:2,C:3'), create('C:3,N:2')]],
            maybe_raises_regex='do not match')

        # error: different number of dims
        self._test_name_inference(
            torch.cat,
            [[create('N:2,C:3'), create('C:3')]],
            maybe_raises_regex='must have same number of dimensions')

        # out=
        self._test_name_inference(
            out_fn(torch.cat),
            [create('0'), [create('N:2,C:3'), create('N:2,C:3')]],
            expected_names=['N', 'C'])

    def test_masked_fill(self):
        # simple
        self._test_name_inference(
            Tensor.masked_fill,
            (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
            expected_names=['N', 'C'])

        # left broadcast
        self._test_name_inference(
            Tensor.masked_fill,
            (create('C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
            maybe_raises_regex="must be less than or equal to")

        # right broadcast
        self._test_name_inference(
            Tensor.masked_fill,
            (create('N:2,C:3'), (create('3') > 0).rename('C'), 3.14),
            expected_names=['N', 'C'])

        # error
        self._test_name_inference(
            Tensor.masked_fill,
            (create('N:2,C:3'), (create('3') > 0).rename('D'), 3.14),
            maybe_raises_regex='do not match')

        # inplace
        self._test_name_inference(
            Tensor.masked_fill_,
            (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
            expected_names=['N', 'C'])

        # inplace, computed names don't match output tensor names
        self._test_name_inference(
            Tensor.masked_fill_,
            (create('N:2,None:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
            maybe_raises_regex="not the same as the computed output names")

    def test_using_seen_interned_string_doesnt_bump_refcount(self):
        def see_name():
            seen_name = 'N'
            pass_name_to_python_arg_parser(seen_name)

        see_name()
        seen_name = 'N'
        old_refcnt = sys.getrefcount(seen_name)

        pass_name_to_python_arg_parser(seen_name)

        new_refcnt = sys.getrefcount(seen_name)
        self.assertEqual(new_refcnt, old_refcnt)

    # This test is failing on Python 3.12: https://github.com/pytorch/pytorch/issues/119464
    @unittest.skipIf(sys.version_info >= (3, 12), "Failing on Python 3.12+")
    def test_using_unseen_interned_string_bumps_refcount_permanently(self):
        # Please don't use this as a name in a different test.
        unseen_name = 'abcdefghi'
        old_refcnt = sys.getrefcount(unseen_name)

        pass_name_to_python_arg_parser(unseen_name)

        new_refcnt = sys.getrefcount(unseen_name)
        self.assertEqual(new_refcnt, old_refcnt + 1)

    # This test is failing on Python 3.12: https://github.com/pytorch/pytorch/issues/119464
    @unittest.skipIf(sys.version_info >= (3, 12), "Failing on Python 3.12+")
    def test_using_unseen_uninterned_string_refcounts(self):
        # Please don't use this as a name in a different test.
        # Non-compile-time constants are not interned.
        unseen_name = ''.join(['abc', 'def', 'ghi', 'jkl'])
        interned_unseen_name = 'abcdefghijkl'
        self.assertIsNot(unseen_name, interned_unseen_name)

        old_uninterned_refcnt = sys.getrefcount(unseen_name)
        old_interned_refcnt = sys.getrefcount(interned_unseen_name)

        pass_name_to_python_arg_parser(unseen_name)

        new_uninterned_refcnt = sys.getrefcount(unseen_name)
        new_interned_refcnt = sys.getrefcount(interned_unseen_name)

        # Internally, PyTorch should not hold a reference to the uninterned string.
        self.assertEqual(new_uninterned_refcnt, old_uninterned_refcnt)

        # Instead, we should hold a new reference to the interned version.
        self.assertEqual(new_interned_refcnt, old_interned_refcnt + 1)
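
    # A minimal illustration of the CPython interning behavior the two tests
    # above rely on (a sketch; interning is an implementation detail of
    # CPython): compile-time string constants are interned, runtime-built
    # strings are not, and sys.intern returns the canonical interned copy.
    def _interning_sketch(self):
        built = ''.join(['abc', 'def'])
        self.assertIsNot(built, 'abcdef')
        self.assertIs(sys.intern(built), 'abcdef')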

    def _test_select(self, device):
        x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device)
        y = x.select(1, 1)
        self.assertEqual(y.names, ('N', 'H', 'W'))

        y = x.select('C', 1)
        self.assertEqual(y.names, ('N', 'H', 'W'))

        with self.assertRaisesRegex(
                RuntimeError, 'Please look up dimensions by name'):
            y = x.select(None, 1)

    def test_select(self):
        self._test_select('cpu')

    @unittest.skipIf(not TEST_CUDA, 'no CUDA')
    def test_select_cuda(self):
        self._test_select('cuda')

    def _test_as_strided(self, device):
        x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device)
        y = x.as_strided([2 * 3 * 4 * 5], [1])
        self.assertEqual(y.names, (None,))

    def test_as_strided(self):
        self._test_as_strided('cpu')

    @unittest.skipIf(not TEST_CUDA, 'no CUDA')
    def test_as_strided_cuda(self):
        self._test_as_strided('cuda')

    def test_no_jit_tracer_support(self):
        def foo(x):
            return torch.full(x.shape, 2., names=('N',))

        with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'):
            x = torch.randn(3)
            torch.jit.trace(foo, example_inputs=x)

        def bar(x):
            return x.select('N', 1)

        with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'):
            x = torch.randn(3)
            torch.jit.trace(bar, example_inputs=x)

    def test_no_jit_script_support(self):
        @torch.jit.script
        def foo(x):
            return x + 1

        with self.assertRaisesRegex(RuntimeError, 'NYI'):
            foo(torch.randn(2, 3, names=('N', 'C')))

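        # `add_names` is exempted from scripting below, so it runs as plain
        # Python and attaches names at runtime; returning the now-named tensor
        # from the scripted function is what trips the NYI error.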
        @torch.jit.ignore
        def add_names(x):
            x.names = ('N', 'C')

        @torch.jit.script
        def return_named_tensor(input):
            add_names(input)
            return input

        with self.assertRaisesRegex(RuntimeError, "NYI"):
            return_named_tensor(torch.randn(1, 1))

    def test_align_to(self):
        # trivial
        tensor = create('N:3')
        output = tensor.align_to('N')
        self.assertEqual(output.names, ['N'])
        self.assertEqual(output.shape, [3])

        # unsqueeze behavior
        tensor = create('N:3')
        output = tensor.align_to('N', 'D')
        self.assertEqual(output.names, ['N', 'D'])
        self.assertEqual(output.shape, [3, 1])

        # transpose behavior
        tensor = create('N:3,C:2')
        output = tensor.align_to('C', 'N')
        self.assertEqual(output.names, ['C', 'N'])
        self.assertEqual(output.shape, [2, 3])

        # unsqueeze / transpose
        tensor = create('C:2,N:3,H:5')
        output = tensor.align_to('N', 'H', 'W', 'C')
        self.assertEqual(output.names, ['N', 'H', 'W', 'C'])
        self.assertEqual(output.shape, [3, 5, 1, 2])

        # All input dimensions must be named
        with self.assertRaisesRegex(RuntimeError, "All input dims must be named. Found unnamed dim at index 0"):
            create('None:2,C:3').align_to('N', 'C')

        # not enough names
        with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'N'"):
            create('N:2,C:3').align_to('C')

        # names not found
        with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'C'"):
            create('N:2,C:3').align_to('D', 'N')

    def test_align_to_ellipsis(self):
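        # '...' expands, in their original order, to all of the input's dims
        # that are not otherwise mentioned in the requested order.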
        tensor = create('N:7,H:3,W:5,C:2')

        # ... = ['N', 'H', 'W', 'C']
        output = tensor.align_to('...')
        self.assertEqual(output.names, ['N', 'H', 'W', 'C'])
        self.assertEqual(output.shape, [7, 3, 5, 2])

        # ... = ['H', 'C']
        output = tensor.align_to('...', 'W', 'N')
        self.assertEqual(output.names, ['H', 'C', 'W', 'N'])
        self.assertEqual(output.shape, [3, 2, 5, 7])

        # ... = ['N', 'W']
        output = tensor.align_to('H', 'C', '...')
        self.assertEqual(output.names, ['H', 'C', 'N', 'W'])
        self.assertEqual(output.shape, [3, 2, 7, 5])

        # ... = ['H', 'C']
        output = tensor.align_to('W', '...', 'N')
        self.assertEqual(output.names, ['W', 'H', 'C', 'N'])
        self.assertEqual(output.shape, [5, 3, 2, 7])

        # ... = []
        output = tensor.align_to('N', '...', 'C', 'D', 'H', 'W')
        self.assertEqual(output.names, ['N', 'C', 'D', 'H', 'W'])
        self.assertEqual(output.shape, [7, 2, 1, 3, 5])

        # Input tensor partially named
        partially_named = create('None:2,None:3,None:5,C:7')
        output = partially_named.align_to('C', '...')
        self.assertEqual(output.names, ['C', None, None, None])
        self.assertEqual(output.shape, [7, 2, 3, 5])

        with self.assertRaisesRegex(RuntimeError, "order of dimensions cannot contain a None"):
            partially_named.align_to('C', None, '...')

        # Input order partially named
        with self.assertRaisesRegex(RuntimeError, "cannot contain a None name"):
            tensor.align_to('...', 'N', None)

        # Input order duplicate names
        with self.assertRaisesRegex(RuntimeError, "duplicate names"):
            tensor.align_to('...', 'N', 'N')

    def test_align_as(self):
        # align_as calls align_to internally; align_to is tested thoroughly
        # above, so only check the basics here.
        tensor = create('C:2,N:3,H:5')
        other = create('N:1,H:1,W:1,C:1')
        output = tensor.align_as(other)
        self.assertEqual(output.names, ['N', 'H', 'W', 'C'])
        self.assertEqual(output.shape, [3, 5, 1, 2])

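    # A sketch of the equivalence noted above (assuming `other` is fully
    # named): t.align_as(other) behaves like t.align_to(*other.names).
    def _align_as_sketch(self):
        t = create('C:2,N:3')
        other = create('N:1,C:1')
        self.assertEqual(t.align_as(other).names, t.align_to(*other.names).names)
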
    @unittest.skip("Not implemented yet")
    def test_align_tensors_two_inputs(self):
        def _test(tensor_namedshape, align_names, expected_sizes, expected_error):
            tensor_names, tensor_sizes = tensor_namedshape
            tensor = torch.empty(*tensor_sizes, names=tensor_names)
            other = torch.empty([1] * len(align_names), names=align_names)
            if expected_error is not None:
                with self.assertRaisesRegex(RuntimeError, expected_error):
                    torch.align_tensors(tensor, other)
                return

            output, _ = torch.align_tensors(tensor, other)
            self.assertEqual(output.shape, expected_sizes)
            self.assertEqual(output.names, align_names)

        Case = namedtuple('Case', [
            'tensor_namedshape',
            'align_names',
            'expected_sizes',
            'expected_error',
        ])

        tests = [
            # basic tests
            Case(tensor_namedshape=(['C'], [2]),
                 align_names=['C'],
                 expected_sizes=[2],
                 expected_error=None),
            Case(tensor_namedshape=(['C'], [2]),
                 align_names=['D'],
                 expected_sizes=None,
                 expected_error='not a subsequence'),

            # single-dim alignment tests
            Case(tensor_namedshape=(['C'], [2]),
                 align_names=['N', 'C'],
                 expected_sizes=[1, 2],
                 expected_error=None),
            Case(tensor_namedshape=[['N'], [2]],
                 align_names=['N', 'C'],
                 expected_sizes=[2, 1],
                 expected_error=None),

            # multiple-dim alignment tests
            Case(tensor_namedshape=[['N', 'C'], [2, 3]],
                 align_names=['N', 'H', 'C', 'W'],
                 expected_sizes=[2, 1, 3, 1],
                 expected_error=None),
            Case(tensor_namedshape=[['N', 'C'], [2, 3]],
                 align_names=['C', 'H', 'N', 'W'],
                 expected_sizes=None,
                 expected_error='not a subsequence'),

            # scalar tensor tests
            Case(tensor_namedshape=[None, [[]]],
                 align_names=['N', 'C'],
                 expected_sizes=[1, 1],
                 expected_error=None),
            Case(tensor_namedshape=[[], [[]]],
                 align_names=[None, None],
                 expected_sizes=[1, 1],
                 expected_error=None),

            # unnamed tensor tests
            Case(tensor_namedshape=[None, [2, 3]],
                 align_names=[None, None],
                 expected_sizes=[2, 3],
                 expected_error=None),
            Case(tensor_namedshape=[None, [2, 3]],
                 align_names=[None, None, None],
                 expected_sizes=[1, 2, 3],
                 expected_error=None),
            Case(tensor_namedshape=[None, [2]],
                 align_names=['N'],
                 expected_sizes=None,
                 expected_error='not a subsequence'),

            # unnamed dim alignment tests
            Case(tensor_namedshape=[[None], [2]],
                 align_names=['N', None],
                 expected_sizes=[1, 2],
                 expected_error=None),
            Case(tensor_namedshape=[[None], [2]],
                 align_names=['N', None, None, None],
                 expected_sizes=[1, 1, 1, 2],
                 expected_error=None),
            Case(tensor_namedshape=[['N'], [2]],
                 align_names=['N', None, None, None],
                 expected_sizes=[2, 1, 1, 1],
                 expected_error=None),
            Case(tensor_namedshape=[[None, 'N', None], [2, 3, 5]],
                 align_names=[None, None, 'N', None],
                 expected_sizes=[1, 2, 3, 5],
                 expected_error=None),
            Case(tensor_namedshape=[[None], [2]],
                 align_names=[None, 'N'],
                 expected_sizes=None,
                 expected_error='absolute position from the right'),
            Case(tensor_namedshape=[None, [2]],
                 align_names=[None, 'N'],
                 expected_sizes=None,
                 expected_error='absolute position from the right'),
            Case(tensor_namedshape=[[None, 'N'], [2, 3]],
                 align_names=[None, 'C', 'N'],
                 expected_sizes=None,
                 expected_error='absolute position from the right'),
        ]

        for test in tests:
            _test(*test)

    @unittest.skip("Not implemented yet")
    def test_align_tensors(self):
        def reference_fn(*tensors):
            longest_names = tensors[0].names
            for tensor in tensors:
                if len(tensor.names) > len(longest_names):
                    longest_names = tensor.names
            return [tensor.align_to(*longest_names) for tensor in tensors]

        x = torch.empty(1, 1, names=('N', 'H'))
        y = torch.empty(2, 3, 5, names=('N', 'C', 'H'))
        z = torch.empty(2, names=('N',))
        output = torch.align_tensors(x, y, z)
        expected_tensors = reference_fn(x, y, z)
        for tensor, expected in zip(output, expected_tensors):
            self.assertTensorDataAndNamesEqual(tensor, expected)

    def test_mm(self):
        for device in get_all_device_types():
            self._test_name_inference(
                torch.mm, device=device,
                args=(create('N:3,C:2'), create('W:2,H:5')),
                expected_names=('N', 'H'))

            # left arg is unnamed
            self._test_name_inference(
                torch.mm, device=device,
                args=(create('3,2'), create('W:2,H:5')),
                expected_names=(None, 'H'))

            # right arg is unnamed
            self._test_name_inference(
                torch.mm, device=device,
                args=(create('N:3,C:2'), create('2,5')),
                expected_names=('N', None))

            # out=
            self._test_name_inference(
                out_fn(torch.mm), device=device,
                args=(create('0'), create('N:3,C:2'), create('W:2,H:5')),
                expected_names=('N', 'H'))

            # duplicate names in the output
            self._test_name_inference(
                torch.mm, device=device,
                args=(create('N:3,C:2'), create('W:2,N:5')),
                maybe_raises_regex='with duplicate names')

    def test_expand(self):
        for device in get_all_device_types():
            self._test_name_inference(
                Tensor.expand, device=device,
                args=(create('D:1'), [3]), expected_names=('D',))

            self._test_name_inference(
                Tensor.expand, device=device,
                args=(create('H:3,W:2'), [10, 3, 3, 2]),
                expected_names=(None, None, 'H', 'W'))

            self._test_name_inference(
                Tensor.expand, device=device,
                args=(create('3,2'), [10, 3, 3, 2]),
                expected_names=(None, None, None, None))

    def test_addmm(self):
        for device in get_all_device_types():
            # full names
            self._test_name_inference(
                torch.addmm, device=device,
                args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')),
                expected_names=('N', 'H'))

            # no name on bias
            self._test_name_inference(
                torch.addmm, device=device,
                args=(create('3,5'), create('N:3,C:2'), create('W:2,H:5')),
                expected_names=('N', 'H'))

            # partially named bias
            self._test_name_inference(
                torch.addmm, device=device,
                args=(create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')),
                expected_names=('N', 'H'))

            # out=
            self._test_name_inference(
                out_fn(torch.addmm), device=device,
                args=(create('0'), create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')),
                expected_names=('N', 'H'))

            # inplace
            self._test_name_inference(
                torch.Tensor.addmm_, device=device,
                args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')),
                expected_names=('N', 'H'))

            # duplicate names in the output
            self._test_name_inference(
                torch.addmm, device=device,
                args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,N:5')),
                maybe_raises_regex='with duplicate names')

    def test_bmm(self):
        for device in get_all_device_types():
            # full names
            self._test_name_inference(
                torch.bmm, device=device,
                args=(create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
                expected_names=('N', 'A', 'B'))

            # no name on left tensor
            self._test_name_inference(
                torch.bmm, device=device,
                args=(create('7,3,2'), create('N:7,A:2,B:5')),
                expected_names=('N', None, 'B'))

            # no name on right tensor
            self._test_name_inference(
                torch.bmm, device=device,
                args=(create('N:7,A:3,B:2'), create('7,2,5')),
                expected_names=('N', 'A', None))

            # out=
            self._test_name_inference(
                out_fn(torch.bmm), device=device,
                args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
                expected_names=('N', 'A', 'B'))

            # duplicate names after mm
            self._test_name_inference(
                torch.bmm, device=device,
                args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')),
                maybe_raises_regex='with duplicate names')

            # matching error (batch dimensions must be alignable)
            self._test_name_inference(
                torch.bmm, device=device,
                args=(create('N:3,A:3,B:3'), create('M:3,A:3,B:3')),
                maybe_raises_regex='do not match')

            # misalignment (batch dimension is getting contracted)
            self._test_name_inference(
                torch.bmm, device=device,
                args=(create('N:3,A:3,B:3'), create('None:3,N:3,B:3')),
                maybe_raises_regex='misaligned')

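    # Sketch of the batch-name rule exercised above: batch dims unify under
    # match semantics, so a None batch name is filled in from the other input.
    def _bmm_batch_name_sketch(self):
        out = torch.bmm(create('None:7,A:3,B:2'), create('N:7,B:2,C:5'))
        self.assertEqual(out.names, ('N', 'A', 'C'))
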
    def test_matmul(self):
        for device in get_all_device_types():
            # input tensors are less than 1D
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create(''), create('A:2')),
                maybe_raises_regex='at least 1D')
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:2'), create('')),
                maybe_raises_regex='at least 1D')

            # 1D @ 1D
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:2'), create('B:2')),
                expected_names=[])

            # ND @ 1D
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:3,C:2'), create('B:2')),
                expected_names=['A'])
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:5,C:3,D:2'), create('B:2')),
                expected_names=['A', 'C'])

            # 1D @ ND
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('C:2'), create('A:2,B:3')),
                expected_names=['B'])
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('C:2'), create('A:3,B:2,D:5')),
                expected_names=['A', 'D'])

            # 2D @ 2D
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:3,B:2'), create('A:2,B:3')),
                expected_names=['A', 'B'])
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('A:3,B:2'), create('B:2,A:5')),
                maybe_raises_regex='with duplicate names')

            # ND @ ND where N >= 2
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('C:5,A:3,B:2'), create('A:2,B:3')),
                expected_names=['C', 'A', 'B'])
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('C:5,A:3,B:2'), create('None:1,A:2,B:3')),
                expected_names=['C', 'A', 'B'])
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('C:5,A:3,B:2'), create('None:2,None:1,A:2,B:3')),
                expected_names=[None, 'C', 'A', 'B'])

            # out=
            self._test_name_inference(
                out_fn(torch.matmul), device=device,
                args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
                expected_names=('N', 'A', 'B'))

            # duplicate names after mm
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')),
                maybe_raises_regex='with duplicate names')

            # misalignment (batch dimension is getting contracted)
            self._test_name_inference(
                torch.matmul, device=device,
                args=(create('N:3,A:3,B:3'), create('A:3,N:3,B:3')),
                maybe_raises_regex='do not match')

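    # In brief, the name-inference rule behind test_mm/test_bmm/test_matmul
    # above: the two contracted dims lose their names, remaining row/column
    # dims keep theirs, and batch dims are matched (and, for matmul, broadcast)
    # by position from the right. The mv/addmv tests below follow the same
    # rule with a 1-D right-hand side.
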
    def test_mv(self):
        for device in get_all_device_types():
            self._test_name_inference(
                torch.mv, device=device,
                args=(create('N:3,C:2'), create('W:2')),
                expected_names=('N',))

            # left arg is unnamed
            self._test_name_inference(
                torch.mv, device=device,
                args=(create('3,2'), create('W:2')),
                expected_names=(None,))

            # right arg is unnamed
            self._test_name_inference(
                torch.mv, device=device,
                args=(create('N:3,C:2'), create('2')),
                expected_names=('N',))

            # out=
            self._test_name_inference(
                out_fn(torch.mv), device=device,
                args=(create('0'), create('N:3,C:2'), create('W:2')),
                expected_names=('N',))

    def test_addmv(self):
        for device in get_all_device_types():
            # full names
            self._test_name_inference(
                torch.addmv, device=device,
                args=(create('N:3'), create('N:3,C:2'), create('H:2')),
                expected_names=['N'])

            # no name on bias
            self._test_name_inference(
                torch.addmv, device=device,
                args=(create('3'), create('N:3,C:2'), create('H:2')),
                expected_names=('N',))

            # out=
            self._test_name_inference(
                out_fn(torch.addmv), device=device,
                args=(create('0'), create('N:3'), create('N:3,C:2'), create('H:2')),
                expected_names=('N',))

            # inplace
            self._test_name_inference(
                torch.Tensor.addmv_, device=device,
                args=(create('N:3'), create('N:3,C:2'), create('H:2')),
                expected_names=('N',))

    def test_autograd_ignores_names(self):
        # sigmoid forward is supported by named tensors, but sigmoid_backward
        # is not (see native_functions.yaml). Test that autograd ignores names
        # and that sigmoid_backward succeeds.
        x = torch.randn(3, 3, names=('N', 'C'), requires_grad=True)
        x.sigmoid().sum().backward()

    def test_tensor_grad_is_unnamed(self):
        x = torch.randn(3, 3, names=(None, None), requires_grad=True)
        y = torch.randn(3, 3, names=('N', 'C'), requires_grad=True)
        (x * y).sum().backward()

        # Check that names weren't propagated to the gradients.
        self.assertEqual(y.grad.names, [None, None])
        self.assertEqual(x.grad.names, [None, None])

    def test_autograd_warns_named_grad(self):
        base = torch.randn(3, 3, names=('N', 'C'))
        named_grad = base.clone()
        base.requires_grad_()

        with warnings.catch_warnings(record=True) as warns:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")
            base.clone().backward(named_grad)
            self.assertEqual(len(warns), 1)
            self.assertTrue(
                str(warns[0].message).startswith('Autograd was passed a named grad tensor'))

    def test_nyi_dimname_overload_msg(self):
        x = torch.randn(3, 3)
        with self.assertRaisesRegex(RuntimeError, "squeeze: You passed a dimname"):
            x.squeeze_("N")

    def test_dot(self):
        for device in get_all_device_types():
            # torch.dot ignores the names of both tensors
            self._test_name_inference(
                torch.dot, device=device,
                args=(create('C:2'), create('W:2')),
                expected_names=[])

    def test_comparison_ops(self):
        for device in get_all_device_types():
            a = torch.randn(3, 3, names=('N', 'C'), device=device)
            b = torch.randn(3, 3, names=('N', 'C'), device=device)
            scalar = torch.randn([], device=device)

            self.assertEqual((a == b).names, ['N', 'C'])
            self.assertEqual((a != b).names, ['N', 'C'])
            self.assertEqual((a > b).names, ['N', 'C'])
            self.assertEqual((a < b).names, ['N', 'C'])
            self.assertEqual((a >= b).names, ['N', 'C'])
            self.assertEqual((a <= b).names, ['N', 'C'])

            self.assertEqual((a == 1).names, ['N', 'C'])
            self.assertEqual((a != 1).names, ['N', 'C'])
            self.assertEqual((a > 1).names, ['N', 'C'])
            self.assertEqual((a < 1).names, ['N', 'C'])
            self.assertEqual((a >= 1).names, ['N', 'C'])
            self.assertEqual((a <= 1).names, ['N', 'C'])

            self.assertEqual((a == scalar).names, ['N', 'C'])
            self.assertEqual((a != scalar).names, ['N', 'C'])
            self.assertEqual((a > scalar).names, ['N', 'C'])
            self.assertEqual((a < scalar).names, ['N', 'C'])
            self.assertEqual((a >= scalar).names, ['N', 'C'])
            self.assertEqual((a <= scalar).names, ['N', 'C'])

            res = torch.empty(3, 3, dtype=torch.bool, device=device)
            torch.eq(a, b, out=res)
            self.assertEqual(res.names, ['N', 'C'])
            torch.ne(a, b, out=res)
            self.assertEqual(res.names, ['N', 'C'])
            torch.lt(a, b, out=res)
            self.assertEqual(res.names, ['N', 'C'])
            torch.gt(a, b, out=res)
            self.assertEqual(res.names, ['N', 'C'])
            torch.le(a, b, out=res)
            self.assertEqual(res.names, ['N', 'C'])
            torch.ge(a, b, out=res)
            self.assertEqual(res.names, ['N', 'C'])

            res = torch.isnan(a)
            self.assertEqual(res.names, ['N', 'C'])

            res = torch.isinf(a)
            self.assertEqual(res.names, ['N', 'C'])

    def test_support_device_named_grad(self):
        named_tensor = torch.randn(3, 3, device='meta')
        # Each of these named-tensor entry points should raise on its own.
        with self.assertRaisesRegex(RuntimeError, 'NYI: named tensors only support CPU, CUDA'):
            named_tensor.rename_('N', 'C')
        with self.assertRaisesRegex(RuntimeError, 'NYI: named tensors only support CPU, CUDA'):
            named_tensor.names = ['N', 'C']
        with self.assertRaisesRegex(RuntimeError, 'NYI: named tensors only support CPU, CUDA'):
            torch.randn(3, 3, device='meta', names=['N', 'C'])


if __name__ == '__main__':
    run_tests()