xref: /aosp_15_r20/external/pytorch/test/test_namedtensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: named tensor"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport unittest
4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY
5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import skipIfTorchDynamo
6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import get_all_device_types
8*da0073e9SAndroid Build Coastguard Workerfrom collections import namedtuple, OrderedDict
9*da0073e9SAndroid Build Coastguard Workerimport itertools
10*da0073e9SAndroid Build Coastguard Workerimport functools
11*da0073e9SAndroid Build Coastguard Workerimport torch
12*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
13*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F
14*da0073e9SAndroid Build Coastguard Workerfrom multiprocessing.reduction import ForkingPickler
15*da0073e9SAndroid Build Coastguard Workerimport pickle
16*da0073e9SAndroid Build Coastguard Workerimport io
17*da0073e9SAndroid Build Coastguard Workerimport sys
18*da0073e9SAndroid Build Coastguard Workerimport warnings
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
def pass_name_to_python_arg_parser(name):
    """Create a throwaway 1-d tensor whose single dimension is named ``name``.

    The tensor is discarded. The point is to push ``name`` through the
    ``names=`` handling of the Python argument parser, so any validation
    error (for example the ``RuntimeError`` raised for an invalid identifier)
    propagates to the caller.
    """
    torch.empty(2, names=(name,))
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker
def flatten(lst):
    """Concatenate the sub-iterables of ``lst`` into one flat list (one level deep)."""
    flat = []
    for sublist in lst:
        flat.extend(sublist)
    return flat
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker
# Presumably pairs a test name with the callable under test (``lambd``).
# The namedtuple typename matches the variable it is bound to, so instances
# repr as ``Function(...)``. It also lets pickle find the class, since pickle
# resolves it by ``__qualname__`` in the defining module; the previous
# typename 'TestCase' resolved to an unrelated object.
Function = namedtuple('Function', ['name', 'lambd'])
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker
def parse_compressed_namedshape(string):
    """Parse a compact description of a (possibly named) tensor shape.

    This is a small metalanguage for describing a tensor's shape compactly:

    * ``''``              -> names ``None``, sizes ``[]``
    * ``'3,2'``           -> names ``None``, sizes ``[3, 2]``
    * ``'N:3,C:2'``       -> names ``('N', 'C')``, sizes ``[3, 2]``
    * ``'None:3,None:2'`` -> names ``(None, None)``, sizes ``[3, 2]``
      (the literal ``None`` marks an unnamed dimension)

    Whitespace around names and sizes is ignored.

    Returns:
        A ``(names, sizes)`` pair, suitable for ``factory(sizes, names=names)``.
        Every branch returns a concrete pair; previously the named case
        returned a lazy ``zip`` iterator that only worked with unpacking.
    """
    def parse_name(maybe_name):
        maybe_name = maybe_name.strip()
        if maybe_name == 'None':
            return None
        return maybe_name

    string = string.strip()

    # '' -> size: [], names: None
    if not string:
        return None, []

    # '3, 2' -> size = [3, 2], None names.
    if ':' not in string:
        return None, [int(size) for size in string.split(',')]

    names = []
    sizes = []
    for dim in string.split(','):
        name, size = dim.split(':')
        names.append(parse_name(name))
        sizes.append(int(size))  # int() tolerates surrounding whitespace
    return tuple(names), sizes
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker
def create(namedshape, factory=torch.randn):
    """Build a tensor from a compressed named-shape string.

    ``namedshape`` uses the ``parse_compressed_namedshape`` syntax (e.g.
    ``'N:3,C:2'``). ``factory`` is called as ``factory(sizes, names=names)``.
    """
    parsed_names, parsed_sizes = parse_compressed_namedshape(namedshape)
    tensor = factory(parsed_sizes, names=parsed_names)
    return tensor
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker
def out_fn(operator):
    """Adapt an ``out=``-taking operator to an all-positional calling style.

    The returned callable treats its first positional argument as the ``out``
    destination and forwards the rest, so ``out_fn(op)(out, a, b)`` calls
    ``op(a, b, out=out)``. ``functools.wraps`` copies ``operator``'s metadata
    (such as ``__name__``) onto the wrapper.
    """
    @functools.wraps(operator)
    def fn(*inputs):
        forwarded = inputs[1:]
        destination = inputs[0]
        return operator(*forwarded, out=destination)
    return fn
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Workerclass TestNamedTensor(TestCase):
72*da0073e9SAndroid Build Coastguard Worker    def test_aaa_must_run_first_check_experimental_warning(self):
73*da0073e9SAndroid Build Coastguard Worker        # TODO(rzou): It would be nice for this to be a "real" python warning.
74*da0073e9SAndroid Build Coastguard Worker        # Right now this error message only prints once and doesn't respect
75*da0073e9SAndroid Build Coastguard Worker        # warnings.simplefilter behavior (where python users can control whether
76*da0073e9SAndroid Build Coastguard Worker        # or not to display warnings once, all the time, or never).
77*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as warns:
78*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3, 3, names=('N', 'C'))
79*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(warns), 1)
80*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(str(warns[0].message).startswith(
81*da0073e9SAndroid Build Coastguard Worker                'Named tensors and all their associated APIs are an experimental feature'))
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker    def test_trivial(self):
84*da0073e9SAndroid Build Coastguard Worker        pass
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker    def _test_name_inference(self, op, args=(), expected_names=(), device='cpu',
87*da0073e9SAndroid Build Coastguard Worker                             maybe_raises_regex=None):
88*da0073e9SAndroid Build Coastguard Worker        casted_args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg
89*da0073e9SAndroid Build Coastguard Worker                       for arg in args]
90*da0073e9SAndroid Build Coastguard Worker        if maybe_raises_regex is not None:
91*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, maybe_raises_regex):
92*da0073e9SAndroid Build Coastguard Worker                result = op(*args)
93*da0073e9SAndroid Build Coastguard Worker            return
94*da0073e9SAndroid Build Coastguard Worker        result = op(*args)
95*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.names, expected_names,
96*da0073e9SAndroid Build Coastguard Worker                         msg=f'Name inference for {op.__name__} on device {device} failed')
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker    # TODO(rzou): Some form of this check should be added to self.assertEqual.
99*da0073e9SAndroid Build Coastguard Worker    # Right now I don't know what it should look like.
100*da0073e9SAndroid Build Coastguard Worker    def assertTensorDataAndNamesEqual(self, x, y):
101*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.names, y.names)
102*da0073e9SAndroid Build Coastguard Worker        unnamed_x = x.rename(None)
103*da0073e9SAndroid Build Coastguard Worker        unnamed_y = y.rename(None)
104*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(unnamed_x, unnamed_y)
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker    def _test_factory(self, factory, device):
107*da0073e9SAndroid Build Coastguard Worker        x = factory([], device=device)
108*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.names, ())
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker        x = factory(1, 2, 3, device=device)
111*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.names, (None, None, None))
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker        x = factory(1, 2, 3, names=None, device=device)
114*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.names, (None, None, None))
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker        x = factory(1, 2, 3, names=('N', 'T', 'D'), device=device)
117*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.names, ('N', 'T', 'D'))
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker        x = factory(1, 2, 3, names=('N', None, 'D'), device=device)
120*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.names, ('N', None, 'D'))
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker        x = factory(1, 2, 3, names=('_1', 'batch9', 'BATCH_5'), device=device)
123*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.names, ('_1', 'batch9', 'BATCH_5'))
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
126*da0073e9SAndroid Build Coastguard Worker                                    'a valid identifier contains only'):
127*da0073e9SAndroid Build Coastguard Worker            x = factory(2, names=('1',), device=device)
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
130*da0073e9SAndroid Build Coastguard Worker                                    'a valid identifier contains only'):
131*da0073e9SAndroid Build Coastguard Worker            x = factory(2, names=('?',), device=device)
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
134*da0073e9SAndroid Build Coastguard Worker            x = factory(2, 1, names=('N',), device=device)
135*da0073e9SAndroid Build Coastguard Worker
136*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, 'invalid combination of arguments'):
137*da0073e9SAndroid Build Coastguard Worker            x = factory(2, 1, names='N', device=device)
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'construct a tensor with duplicate names'):
140*da0073e9SAndroid Build Coastguard Worker            x = factory(2, 1, 1, names=('N', 'C', 'N'), device=device)
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker        names64 = ['A' * i for i in range(1, 65)]
143*da0073e9SAndroid Build Coastguard Worker        x = factory([1] * 64, names=names64, device=device)
144*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.names, names64)
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
147*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
148*da0073e9SAndroid Build Coastguard Worker                'only support up to 64 dims'):
149*da0073e9SAndroid Build Coastguard Worker            names65 = ['A' * i for i in range(1, 66)]
150*da0073e9SAndroid Build Coastguard Worker            x = factory([1] * 65, names=names64, device=device)
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("not a bug: Dynamo causes the refcounts to be different")
153*da0073e9SAndroid Build Coastguard Worker    def test_none_names_refcount(self, N=10):
154*da0073e9SAndroid Build Coastguard Worker        def scope():
155*da0073e9SAndroid Build Coastguard Worker            unnamed = torch.empty(2, 3)
156*da0073e9SAndroid Build Coastguard Worker            unnamed.names  # materialize [None, None]
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker        prev_none_refcnt = sys.getrefcount(None)
159*da0073e9SAndroid Build Coastguard Worker        # Ran it N times to reduce flakiness
160*da0073e9SAndroid Build Coastguard Worker        [scope() for i in range(N)]
161*da0073e9SAndroid Build Coastguard Worker        after_none_refcnt = sys.getrefcount(None)
162*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(after_none_refcnt - prev_none_refcnt < N / 2,
163*da0073e9SAndroid Build Coastguard Worker                        msg='Using tensor.names should not change '
164*da0073e9SAndroid Build Coastguard Worker                            'the refcount of Py_None')
165*da0073e9SAndroid Build Coastguard Worker
166*da0073e9SAndroid Build Coastguard Worker    def test_has_names(self):
167*da0073e9SAndroid Build Coastguard Worker        unnamed = torch.empty(2, 3)
168*da0073e9SAndroid Build Coastguard Worker        none_named = torch.empty(2, 3, names=(None, None))
169*da0073e9SAndroid Build Coastguard Worker        partially_named = torch.empty(2, 3, names=('N', None))
170*da0073e9SAndroid Build Coastguard Worker        fully_named = torch.empty(2, 3, names=('N', 'C'))
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(unnamed.has_names())
173*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(none_named.has_names())
174*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(partially_named.has_names())
175*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(fully_named.has_names())
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker    def test_py3_ellipsis(self):
178*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(2, 3, 5, 7)
179*da0073e9SAndroid Build Coastguard Worker        output = tensor.refine_names('N', ..., 'C')
180*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['N', None, None, 'C'])
181*da0073e9SAndroid Build Coastguard Worker
182*da0073e9SAndroid Build Coastguard Worker    def test_refine_names(self):
183*da0073e9SAndroid Build Coastguard Worker        # Unnamed tensor -> Unnamed tensor
184*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(Tensor.refine_names,
185*da0073e9SAndroid Build Coastguard Worker                                  [create('None:1,None:2,None:3'), 'N', 'C', 'H'],
186*da0073e9SAndroid Build Coastguard Worker                                  ['N', 'C', 'H'])
187*da0073e9SAndroid Build Coastguard Worker
188*da0073e9SAndroid Build Coastguard Worker        # Named tensor -> Named tensor
189*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(Tensor.refine_names,
190*da0073e9SAndroid Build Coastguard Worker                                  [create('N:1,C:2,H:3'), 'N', 'C', 'H'],
191*da0073e9SAndroid Build Coastguard Worker                                  ['N', 'C', 'H'])
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker        # Partially named tensor -> named tensor
194*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(Tensor.refine_names,
195*da0073e9SAndroid Build Coastguard Worker                                  [create('None:1,C:2,None:3'), None, 'C', 'H'],
196*da0073e9SAndroid Build Coastguard Worker                                  [None, 'C', 'H'])
197*da0073e9SAndroid Build Coastguard Worker
198*da0073e9SAndroid Build Coastguard Worker        # Too few names
199*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(Tensor.refine_names,
200*da0073e9SAndroid Build Coastguard Worker                                  [create('None:2,None:3'), 'N', 'C', 'H'],
201*da0073e9SAndroid Build Coastguard Worker                                  maybe_raises_regex="different number of dims")
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker        # Cannot change Tensor[D] to Tensor[N]
204*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(Tensor.refine_names,
205*da0073e9SAndroid Build Coastguard Worker                                  [create('D:3'), 'N'],
206*da0073e9SAndroid Build Coastguard Worker                                  maybe_raises_regex="is different from")
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker        # Cannot change Tensor[D] to Tensor[None]
209*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(Tensor.refine_names,
210*da0073e9SAndroid Build Coastguard Worker                                  [create('D:3'), None],
211*da0073e9SAndroid Build Coastguard Worker                                  maybe_raises_regex="'D' is more specific than None")
212*da0073e9SAndroid Build Coastguard Worker
213*da0073e9SAndroid Build Coastguard Worker        # globbing behavior exists
214*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(Tensor.refine_names,
215*da0073e9SAndroid Build Coastguard Worker                                  [create('None:1,None:1,None:2,None:3'), '...', 'C', 'H'],
216*da0073e9SAndroid Build Coastguard Worker                                  [None, None, 'C', 'H'])
217*da0073e9SAndroid Build Coastguard Worker
218*da0073e9SAndroid Build Coastguard Worker    def test_detach(self):
219*da0073e9SAndroid Build Coastguard Worker        names = ['N']
220*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
221*da0073e9SAndroid Build Coastguard Worker            Tensor.detach_,
222*da0073e9SAndroid Build Coastguard Worker            [torch.randn(3, requires_grad=True, names=names)],
223*da0073e9SAndroid Build Coastguard Worker            names)
224*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
225*da0073e9SAndroid Build Coastguard Worker            Tensor.detach,
226*da0073e9SAndroid Build Coastguard Worker            [torch.randn(3, requires_grad=True, names=names)],
227*da0073e9SAndroid Build Coastguard Worker            names)
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker    def test_index_fill(self):
230*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
231*da0073e9SAndroid Build Coastguard Worker            expected_names = ('N', 'C')
232*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3, 5, device=device, names=expected_names)
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker            output = x.index_fill_('C', torch.tensor([0, 1], device=device), 5)
235*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output.names, expected_names)
236*da0073e9SAndroid Build Coastguard Worker
237*da0073e9SAndroid Build Coastguard Worker            output = x.index_fill_('C', torch.tensor([0, 1], device=device), torch.tensor(4.))
238*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output.names, expected_names)
239*da0073e9SAndroid Build Coastguard Worker
240*da0073e9SAndroid Build Coastguard Worker            output = x.index_fill('C', torch.tensor([0, 1], device=device), 5)
241*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output.names, expected_names)
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker            output = x.index_fill('C', torch.tensor([0, 1], device=device), torch.tensor(4.))
244*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output.names, expected_names)
245*da0073e9SAndroid Build Coastguard Worker
246*da0073e9SAndroid Build Coastguard Worker    def test_equal(self):
247*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
248*da0073e9SAndroid Build Coastguard Worker            tensor = torch.randn(2, 3, device=device)
249*da0073e9SAndroid Build Coastguard Worker            other = tensor.clone()
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.equal(tensor.rename('N', 'C'), other.rename('N', 'C')))
252*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.equal(tensor.rename('M', 'C'), other.rename('N', 'C')))
253*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.equal(tensor.rename(None, 'C'), other.rename('N', 'C')))
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker    def test_squeeze(self):
256*da0073e9SAndroid Build Coastguard Worker        x = create('N:3,C:1,H:1,W:1')
257*da0073e9SAndroid Build Coastguard Worker        output = x.squeeze('C')
258*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['N', 'H', 'W'])
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker        output = x.squeeze()
261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['N'])
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker    def test_repr(self):
264*da0073e9SAndroid Build Coastguard Worker        named_tensor = torch.zeros(2, 3).rename_('N', 'C')
265*da0073e9SAndroid Build Coastguard Worker        expected = "tensor([[0., 0., 0.],\n        [0., 0., 0.]], names=('N', 'C'))"
266*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(repr(named_tensor), expected)
267*da0073e9SAndroid Build Coastguard Worker
268*da0073e9SAndroid Build Coastguard Worker        unnamed_tensor = torch.zeros(2, 3)
269*da0073e9SAndroid Build Coastguard Worker        expected = "tensor([[0., 0., 0.],\n        [0., 0., 0.]])"
270*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(repr(unnamed_tensor), expected)
271*da0073e9SAndroid Build Coastguard Worker
272*da0073e9SAndroid Build Coastguard Worker        none_named_tensor = torch.zeros(2, 3).rename_(None, None)
273*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(repr(none_named_tensor), expected)
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker    def test_diagonal(self):
276*da0073e9SAndroid Build Coastguard Worker        named_tensor = torch.zeros(2, 3, 5, 7, names=list('ABCD'))
277*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor.diagonal().names, ['C', 'D', None])
278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor.diagonal(1, 3).names, ['A', 'C', None])
279*da0073e9SAndroid Build Coastguard Worker
280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor.diagonal(outdim='E', dim1='B', dim2='D').names,
281*da0073e9SAndroid Build Coastguard Worker                         ['A', 'C', 'E'])
282*da0073e9SAndroid Build Coastguard Worker
283*da0073e9SAndroid Build Coastguard Worker    def test_max_pooling(self):
284*da0073e9SAndroid Build Coastguard Worker        def check_tuple_return(op, inputs, expected_names):
285*da0073e9SAndroid Build Coastguard Worker            values, indices = op(*inputs)
286*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values.names, expected_names)
287*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(indices.names, expected_names)
288*da0073e9SAndroid Build Coastguard Worker
289*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
290*da0073e9SAndroid Build Coastguard Worker
291*da0073e9SAndroid Build Coastguard Worker            named_tensor_1d = torch.zeros(2, 3, 5, device=device, names=list('ABC'))
292*da0073e9SAndroid Build Coastguard Worker            named_tensor_2d = torch.zeros(2, 3, 5, 7, device=device, names=list('ABCD'))
293*da0073e9SAndroid Build Coastguard Worker            named_tensor_3d = torch.zeros(2, 3, 5, 7, 9, device=device, names=list('ABCDE'))
294*da0073e9SAndroid Build Coastguard Worker
295*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(F.max_pool1d(named_tensor_1d, 2).names, named_tensor_1d.names)
296*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(F.max_pool2d(named_tensor_2d, [2, 2]).names, named_tensor_2d.names)
297*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(F.max_pool3d(named_tensor_3d, [2, 2, 2]).names, named_tensor_3d.names)
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker            check_tuple_return(F.max_pool1d_with_indices, [named_tensor_1d, 2], named_tensor_1d.names)
300*da0073e9SAndroid Build Coastguard Worker            check_tuple_return(F.max_pool2d_with_indices, [named_tensor_2d, [2, 2]], named_tensor_2d.names)
301*da0073e9SAndroid Build Coastguard Worker            check_tuple_return(F.max_pool3d_with_indices, [named_tensor_3d, [2, 2, 2]], named_tensor_3d.names)
302*da0073e9SAndroid Build Coastguard Worker
303*da0073e9SAndroid Build Coastguard Worker    def test_max_pooling_without_names_does_not_warn(self):
304*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
305*da0073e9SAndroid Build Coastguard Worker            tensor_2d = torch.zeros(2, 3, 5, 7, device=device, requires_grad=True)
306*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as warns:
307*da0073e9SAndroid Build Coastguard Worker                warnings.simplefilter("always")
308*da0073e9SAndroid Build Coastguard Worker                result = F.max_pool2d(tensor_2d, [2, 2])
309*da0073e9SAndroid Build Coastguard Worker                result.sum().backward()
310*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(warns), 0)
311*da0073e9SAndroid Build Coastguard Worker
312*da0073e9SAndroid Build Coastguard Worker    def test_no_save_support(self):
313*da0073e9SAndroid Build Coastguard Worker        named_tensor = torch.zeros(2, 3, names=('N', 'C'))
314*da0073e9SAndroid Build Coastguard Worker        buf = io.BytesIO()
315*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "NYI"):
316*da0073e9SAndroid Build Coastguard Worker            torch.save(named_tensor, buf)
317*da0073e9SAndroid Build Coastguard Worker
318*da0073e9SAndroid Build Coastguard Worker    def test_no_pickle_support(self):
319*da0073e9SAndroid Build Coastguard Worker        named_tensor = torch.zeros(2, 3, names=('N', 'C'))
320*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "NYI"):
321*da0073e9SAndroid Build Coastguard Worker            serialized = pickle.dumps(named_tensor)
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Worker    def test_no_multiprocessing_support(self):
324*da0073e9SAndroid Build Coastguard Worker        named_tensor = torch.zeros(2, 3, names=('N', 'C'))
325*da0073e9SAndroid Build Coastguard Worker        buf = io.BytesIO()
326*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "NYI"):
327*da0073e9SAndroid Build Coastguard Worker            ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(named_tensor)
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker    def test_big_tensor_repr_has_names(self):
330*da0073e9SAndroid Build Coastguard Worker        def check_repr(named_tensor):
331*da0073e9SAndroid Build Coastguard Worker            unnamed_tensor = named_tensor.rename(None)
332*da0073e9SAndroid Build Coastguard Worker            names_tag = f'names={named_tensor.names}'
333*da0073e9SAndroid Build Coastguard Worker            self.assertIn(names_tag, repr(named_tensor))
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker        check_repr(torch.randn(128, 3, 64, 64, names=('N', 'C', 'H', 'W')))
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker    def test_noncontig_contiguous(self):
338*da0073e9SAndroid Build Coastguard Worker        # This type of contiguous is special-cased and therefore needs its own test
339*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
340*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(2, 3, device=device).t().rename_('N', 'C')
341*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.contiguous().names, ('N', 'C'))
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Worker    def test_copy_transpose(self):
344*da0073e9SAndroid Build Coastguard Worker        # This type of copy is special-cased and therefore needs its own test
345*da0073e9SAndroid Build Coastguard Worker        def _test(self_names, other_names, expected_names):
346*da0073e9SAndroid Build Coastguard Worker            x = torch.empty(2, 5, names=self_names)
347*da0073e9SAndroid Build Coastguard Worker            y = torch.empty(5, 2).t().rename_(*other_names)
348*da0073e9SAndroid Build Coastguard Worker            x.copy_(y)
349*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.names, expected_names)
350*da0073e9SAndroid Build Coastguard Worker
351*da0073e9SAndroid Build Coastguard Worker        _test(('N', 'C'), ('N', 'C'), ('N', 'C'))
352*da0073e9SAndroid Build Coastguard Worker        _test(None, ('N', 'C'), ('N', 'C'))
353*da0073e9SAndroid Build Coastguard Worker
354*da0073e9SAndroid Build Coastguard Worker    def test_rename_(self):
355*da0073e9SAndroid Build Coastguard Worker        tensor = torch.empty(1, 1, names=('N', 'C'))
356*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.rename_(None).names, (None, None))
357*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.rename_('H', 'W').names, ('H', 'W'))
358*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
359*da0073e9SAndroid Build Coastguard Worker            tensor.rename_('N', 'C', 'W')
360*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'duplicate names'):
361*da0073e9SAndroid Build Coastguard Worker            tensor.rename_('N', 'N')
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker    def test_rename(self):
364*da0073e9SAndroid Build Coastguard Worker        tensor = torch.empty(1, 1, names=('N', 'C'))
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.rename(None).names, (None, None))
367*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.rename('H', 'W').names, ('H', 'W'))
368*da0073e9SAndroid Build Coastguard Worker
369*da0073e9SAndroid Build Coastguard Worker        # Check that we didn't modify tensor.names
370*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.names, ('N', 'C'))
371*da0073e9SAndroid Build Coastguard Worker
372*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
373*da0073e9SAndroid Build Coastguard Worker            tensor.rename('N', 'C', 'W')
374*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'duplicate names'):
375*da0073e9SAndroid Build Coastguard Worker            tensor.rename('N', 'N')
376*da0073e9SAndroid Build Coastguard Worker
377*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'either positional args or keyword args'):
378*da0073e9SAndroid Build Coastguard Worker            tensor.rename(None, N='batch')
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker        # rename returns a view on the tensor
381*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.rename('H', 'W').data_ptr(), tensor.data_ptr())
382*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.rename(None).data_ptr(), tensor.data_ptr())
383*da0073e9SAndroid Build Coastguard Worker
384*da0073e9SAndroid Build Coastguard Worker    def test_rename_globber(self):
385*da0073e9SAndroid Build Coastguard Worker        scalar = torch.randn([])
386*da0073e9SAndroid Build Coastguard Worker        unnamed_tensor = torch.empty(1, 1, 1, 1)
387*da0073e9SAndroid Build Coastguard Worker        named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W'))
388*da0073e9SAndroid Build Coastguard Worker
389*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scalar.rename(None).names, [])
390*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(scalar.rename('...').names, [])
391*da0073e9SAndroid Build Coastguard Worker
392*da0073e9SAndroid Build Coastguard Worker        # Check that it works with unnamed tensors
393*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(unnamed_tensor.rename('...').names, unnamed_tensor.names)
394*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(unnamed_tensor.rename('...', 'H', 'W').names,
395*da0073e9SAndroid Build Coastguard Worker                         [None, None, 'H', 'W'])
396*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(unnamed_tensor.rename('N', '...', 'W').names,
397*da0073e9SAndroid Build Coastguard Worker                         ['N', None, None, 'W'])
398*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(unnamed_tensor.rename('N', 'C', '...').names,
399*da0073e9SAndroid Build Coastguard Worker                         ['N', 'C', None, None])
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker        # Check that it works with named tensors
402*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor.rename('...').names, named_tensor.names)
403*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor.rename('...', 'width').names,
404*da0073e9SAndroid Build Coastguard Worker                         ['N', 'C', 'H', 'width'])
405*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor.rename('batch', 'channels', '...', 'width').names,
406*da0073e9SAndroid Build Coastguard Worker                         ['batch', 'channels', 'H', 'width'])
407*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor.rename('batch', '...').names,
408*da0073e9SAndroid Build Coastguard Worker                         ['batch', 'C', 'H', 'W'])
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker        # Test empty glob
411*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(unnamed_tensor.rename('...', None, None, None, None).names,
412*da0073e9SAndroid Build Coastguard Worker                         [None, None, None, None])
413*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor.rename('N', 'C', 'H', '...', 'W').names,
414*da0073e9SAndroid Build Coastguard Worker                         ['N', 'C', 'H', 'W'])
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker        # Multiple globs throw
417*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'More than one '):
418*da0073e9SAndroid Build Coastguard Worker            named_tensor.rename('...', 'channels', '...')
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker    def test_rename_rename_map(self):
421*da0073e9SAndroid Build Coastguard Worker        scalar = torch.randn([])
422*da0073e9SAndroid Build Coastguard Worker        unnamed_tensor = torch.empty(1, 1, 1, 1)
423*da0073e9SAndroid Build Coastguard Worker        named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W'))
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"):
426*da0073e9SAndroid Build Coastguard Worker            scalar.rename(N='batch')
427*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"):
428*da0073e9SAndroid Build Coastguard Worker            unnamed_tensor.rename(N='batch')
429*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"):
430*da0073e9SAndroid Build Coastguard Worker            named_tensor.rename(B='batch')
431*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"):
432*da0073e9SAndroid Build Coastguard Worker            named_tensor.rename(H='height', B='batch')
433*da0073e9SAndroid Build Coastguard Worker
434*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor.rename(N='batch').data_ptr(),
435*da0073e9SAndroid Build Coastguard Worker                         named_tensor.data_ptr())
436*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor.rename(N='batch').names,
437*da0073e9SAndroid Build Coastguard Worker                         ['batch', 'C', 'H', 'W'])
438*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(named_tensor.rename(N='batch', H='height').names,
439*da0073e9SAndroid Build Coastguard Worker                         ['batch', 'C', 'height', 'W'])
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker    def test_set_names_property(self):
442*da0073e9SAndroid Build Coastguard Worker        tensor = torch.empty(1, 1, names=('N', 'C'))
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker        tensor.names = None
445*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.names, (None, None))
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker        tensor.names = ('N', 'W')
448*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.names, ('N', 'W'))
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
451*da0073e9SAndroid Build Coastguard Worker            tensor.names = ['N', 'C', 'W']
452*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'duplicate names'):
453*da0073e9SAndroid Build Coastguard Worker            tensor.names = ['N', 'N']
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Worker    def test_factory_edge_cases(self):
456*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
457*da0073e9SAndroid Build Coastguard Worker            self._test_factory(torch.empty, device)
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker    def test_factory_coverage(self):
460*da0073e9SAndroid Build Coastguard Worker        def _test(factory, device):
461*da0073e9SAndroid Build Coastguard Worker            names = ('N', 'T', 'D')
462*da0073e9SAndroid Build Coastguard Worker
463*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(0)
464*da0073e9SAndroid Build Coastguard Worker            result = factory(1, 2, 3, names=names, device=device)
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker            torch.manual_seed(0)
467*da0073e9SAndroid Build Coastguard Worker            expected = factory(1, 2, 3, device=device).rename_(*names)
468*da0073e9SAndroid Build Coastguard Worker
469*da0073e9SAndroid Build Coastguard Worker            self.assertTensorDataAndNamesEqual(result, expected)
470*da0073e9SAndroid Build Coastguard Worker
471*da0073e9SAndroid Build Coastguard Worker        supported = [
472*da0073e9SAndroid Build Coastguard Worker            torch.ones,
473*da0073e9SAndroid Build Coastguard Worker            torch.rand,
474*da0073e9SAndroid Build Coastguard Worker            torch.randn,
475*da0073e9SAndroid Build Coastguard Worker            torch.zeros,
476*da0073e9SAndroid Build Coastguard Worker        ]
477*da0073e9SAndroid Build Coastguard Worker
478*da0073e9SAndroid Build Coastguard Worker        for op, device in itertools.product(supported, get_all_device_types()):
479*da0073e9SAndroid Build Coastguard Worker            _test(op, device)
480*da0073e9SAndroid Build Coastguard Worker
481*da0073e9SAndroid Build Coastguard Worker        # Test torch.full
482*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
483*da0073e9SAndroid Build Coastguard Worker            names = ('N', 'T', 'D')
484*da0073e9SAndroid Build Coastguard Worker            result = torch.full([1, 2, 3], 2., names=names, device=device)
485*da0073e9SAndroid Build Coastguard Worker            expected = torch.full([1, 2, 3], 2., device=device).rename_(*names)
486*da0073e9SAndroid Build Coastguard Worker            self.assertTensorDataAndNamesEqual(result, expected)
487*da0073e9SAndroid Build Coastguard Worker
488*da0073e9SAndroid Build Coastguard Worker    def test_tensor_from_lists(self):
489*da0073e9SAndroid Build Coastguard Worker        names = ('N', 'C')
490*da0073e9SAndroid Build Coastguard Worker        tensor = torch.tensor([[1]], names=names)
491*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.names, names)
492*da0073e9SAndroid Build Coastguard Worker
493*da0073e9SAndroid Build Coastguard Worker        names = ('N',)
494*da0073e9SAndroid Build Coastguard Worker        tensor = torch.tensor([1], names=names)
495*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.names, names)
496*da0073e9SAndroid Build Coastguard Worker
497*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Number of names'):
498*da0073e9SAndroid Build Coastguard Worker            names = ('N', 'C')
499*da0073e9SAndroid Build Coastguard Worker            tensor = torch.tensor([1], names=names)
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_NUMPY, "no numpy")
502*da0073e9SAndroid Build Coastguard Worker    def test_tensor_from_numpy(self):
503*da0073e9SAndroid Build Coastguard Worker        import numpy as np
504*da0073e9SAndroid Build Coastguard Worker        arr = np.array([[1]])
505*da0073e9SAndroid Build Coastguard Worker        names = ('N', 'C')
506*da0073e9SAndroid Build Coastguard Worker        tensor = torch.tensor([[1]], names=names)
507*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.names, names)
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker    def test_tensor_from_tensor(self):
510*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1)
511*da0073e9SAndroid Build Coastguard Worker        names = ('N', 'C')
512*da0073e9SAndroid Build Coastguard Worker        tensor = torch.tensor(x, names=names)
513*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.names, names)
514*da0073e9SAndroid Build Coastguard Worker
515*da0073e9SAndroid Build Coastguard Worker    def test_tensor_from_named_tensor(self):
516*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1, names=('N', 'D'))
517*da0073e9SAndroid Build Coastguard Worker        tensor = torch.tensor(x)
518*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.names, ('N', 'D'))
519*da0073e9SAndroid Build Coastguard Worker
520*da0073e9SAndroid Build Coastguard Worker        # there's no way to distinguish between names=None and not passing in names.
521*da0073e9SAndroid Build Coastguard Worker        # If the user passes in names=None they are asking for trouble.
522*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1, names=('N', 'D'))
523*da0073e9SAndroid Build Coastguard Worker        tensor = torch.tensor(x, names=None)
524*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.names, ('N', 'D'))
525*da0073e9SAndroid Build Coastguard Worker
526*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 1, names=('N', 'D'))
527*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
528*da0073e9SAndroid Build Coastguard Worker            tensor = torch.tensor(x, names=('N', 'C'))
529*da0073e9SAndroid Build Coastguard Worker
530*da0073e9SAndroid Build Coastguard Worker    def test_size(self):
531*da0073e9SAndroid Build Coastguard Worker        t = torch.empty(2, 3, 5, names=('N', None, 'C'))
532*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.size('N'), 2)
533*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.size('C'), 5)
534*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
535*da0073e9SAndroid Build Coastguard Worker            t.size('channels')
536*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
537*da0073e9SAndroid Build Coastguard Worker            torch.empty(2, 3, 4).size('N')
538*da0073e9SAndroid Build Coastguard Worker
539*da0073e9SAndroid Build Coastguard Worker    def test_stride(self):
540*da0073e9SAndroid Build Coastguard Worker        t = torch.empty(2, 3, 5, names=('N', None, 'C'))
541*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.stride('N'), 3 * 5)
542*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.stride('C'), 1)
543*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
544*da0073e9SAndroid Build Coastguard Worker            t.stride('channels')
545*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
546*da0073e9SAndroid Build Coastguard Worker            torch.empty(2, 3, 4).stride('N')
547*da0073e9SAndroid Build Coastguard Worker
548*da0073e9SAndroid Build Coastguard Worker    def test_transpose_variants(self):
549*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
550*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.transpose('N', 'C').names, ['C', 'N', 'H', 'W'])
551*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.transpose(1, 3).names, ['N', 'W', 'H', 'C'])
552*da0073e9SAndroid Build Coastguard Worker
553*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(2, 3, names=('N', 'C'))
554*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.t().names, ['C', 'N'])
555*da0073e9SAndroid Build Coastguard Worker
556*da0073e9SAndroid Build Coastguard Worker    def test_resize(self):
557*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
558*da0073e9SAndroid Build Coastguard Worker            named = torch.randn(2, names=('N',), device=device)
559*da0073e9SAndroid Build Coastguard Worker            named.resize_([2])
560*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(named.names, ['N'])
561*da0073e9SAndroid Build Coastguard Worker
562*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Cannot resize named tensor"):
563*da0073e9SAndroid Build Coastguard Worker                named.resize_([3])
564*da0073e9SAndroid Build Coastguard Worker
565*da0073e9SAndroid Build Coastguard Worker            other_named = torch.randn(2, names=('N',), device=device)
566*da0073e9SAndroid Build Coastguard Worker            named.resize_as_(other_named)
567*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(other_named.names, ['N'])
568*da0073e9SAndroid Build Coastguard Worker
569*da0073e9SAndroid Build Coastguard Worker            unnamed = torch.randn(2, device=device)
570*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
571*da0073e9SAndroid Build Coastguard Worker                    RuntimeError, r'names .* are not the same as the computed output names'):
572*da0073e9SAndroid Build Coastguard Worker                named.resize_as_(unnamed)
573*da0073e9SAndroid Build Coastguard Worker
574*da0073e9SAndroid Build Coastguard Worker            unnamed = torch.randn(1, device=device)
575*da0073e9SAndroid Build Coastguard Worker            unnamed.resize_as_(named)
576*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unnamed.names, ['N'])
577*da0073e9SAndroid Build Coastguard Worker
578*da0073e9SAndroid Build Coastguard Worker    def test_cdist(self):
579*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
580*da0073e9SAndroid Build Coastguard Worker            tensor = torch.randn(3, 1, 2, 7, names=('M', 'N', 'first_group', 'features'),
581*da0073e9SAndroid Build Coastguard Worker                                 device=device)
582*da0073e9SAndroid Build Coastguard Worker            other = torch.randn(5, 11, 7, names=('N', 'second_group', 'features'),
583*da0073e9SAndroid Build Coastguard Worker                                device=device)
584*da0073e9SAndroid Build Coastguard Worker            result = torch.cdist(tensor, other)
585*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.names, ['M', 'N', 'first_group', 'second_group'])
586*da0073e9SAndroid Build Coastguard Worker
587*da0073e9SAndroid Build Coastguard Worker    def test_info_smoke(self):
588*da0073e9SAndroid Build Coastguard Worker        # Smoke test for info functions / methods / attributes on named tensors.
589*da0073e9SAndroid Build Coastguard Worker        tensor = torch.empty(1, 1, names=('N', 'D'))
590*da0073e9SAndroid Build Coastguard Worker
591*da0073e9SAndroid Build Coastguard Worker        tensor.device
592*da0073e9SAndroid Build Coastguard Worker        tensor.dtype
593*da0073e9SAndroid Build Coastguard Worker        tensor.get_device()
594*da0073e9SAndroid Build Coastguard Worker        tensor.is_complex()
595*da0073e9SAndroid Build Coastguard Worker        tensor.is_floating_point()
596*da0073e9SAndroid Build Coastguard Worker        tensor.is_nonzero()
597*da0073e9SAndroid Build Coastguard Worker        torch.is_same_size(tensor, tensor)
598*da0073e9SAndroid Build Coastguard Worker        torch.is_signed(tensor)
599*da0073e9SAndroid Build Coastguard Worker        tensor.layout
600*da0073e9SAndroid Build Coastguard Worker        tensor.numel()
601*da0073e9SAndroid Build Coastguard Worker        tensor.dim()
602*da0073e9SAndroid Build Coastguard Worker        tensor.element_size()
603*da0073e9SAndroid Build Coastguard Worker        tensor.is_contiguous()
604*da0073e9SAndroid Build Coastguard Worker        tensor.is_cuda
605*da0073e9SAndroid Build Coastguard Worker        tensor.is_leaf
606*da0073e9SAndroid Build Coastguard Worker        tensor.is_pinned()
607*da0073e9SAndroid Build Coastguard Worker        tensor.is_shared()
608*da0073e9SAndroid Build Coastguard Worker        tensor.is_sparse
609*da0073e9SAndroid Build Coastguard Worker        tensor.ndimension()
610*da0073e9SAndroid Build Coastguard Worker        tensor.nelement()
611*da0073e9SAndroid Build Coastguard Worker        tensor.shape
612*da0073e9SAndroid Build Coastguard Worker        tensor.size()
613*da0073e9SAndroid Build Coastguard Worker        tensor.size(1)
614*da0073e9SAndroid Build Coastguard Worker        tensor.storage()
615*da0073e9SAndroid Build Coastguard Worker        tensor.storage_offset()
616*da0073e9SAndroid Build Coastguard Worker        tensor.storage_type()
617*da0073e9SAndroid Build Coastguard Worker        tensor.stride()
618*da0073e9SAndroid Build Coastguard Worker        tensor.stride(1)
619*da0073e9SAndroid Build Coastguard Worker        tensor.data
620*da0073e9SAndroid Build Coastguard Worker        tensor.data_ptr()
621*da0073e9SAndroid Build Coastguard Worker        tensor.ndim
622*da0073e9SAndroid Build Coastguard Worker        tensor.item()
623*da0073e9SAndroid Build Coastguard Worker        tensor.type()
624*da0073e9SAndroid Build Coastguard Worker        tensor.is_shared()
625*da0073e9SAndroid Build Coastguard Worker        tensor.is_signed()
626*da0073e9SAndroid Build Coastguard Worker
627*da0073e9SAndroid Build Coastguard Worker    def test_autograd_smoke(self):
628*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3, names=('N', 'D'), requires_grad=True)
629*da0073e9SAndroid Build Coastguard Worker
630*da0073e9SAndroid Build Coastguard Worker        y = x.clone()
631*da0073e9SAndroid Build Coastguard Worker        y.retain_grad()
632*da0073e9SAndroid Build Coastguard Worker        y.register_hook(lambda x: x)
633*da0073e9SAndroid Build Coastguard Worker
634*da0073e9SAndroid Build Coastguard Worker        y.sum().backward()
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker        # autograd related attributes
637*da0073e9SAndroid Build Coastguard Worker        tensor = torch.empty(1, 1, names=('N', 'D'), requires_grad=True)
638*da0073e9SAndroid Build Coastguard Worker        tensor = tensor.relu()
639*da0073e9SAndroid Build Coastguard Worker        tensor.output_nr
640*da0073e9SAndroid Build Coastguard Worker        tensor.grad_fn
641*da0073e9SAndroid Build Coastguard Worker        tensor.requires_grad
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker    def test_split_fns_propagates_names(self):
644*da0073e9SAndroid Build Coastguard Worker        fns = [
645*da0073e9SAndroid Build Coastguard Worker            lambda x: x.split(1, 0),
646*da0073e9SAndroid Build Coastguard Worker            lambda x: x.split([1, 1], 1),
647*da0073e9SAndroid Build Coastguard Worker            lambda x: x.chunk(2, 0),
648*da0073e9SAndroid Build Coastguard Worker        ]
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
651*da0073e9SAndroid Build Coastguard Worker            orig_tensor = torch.empty(2, 2, names=('N', 'D'), device=device)
652*da0073e9SAndroid Build Coastguard Worker            for fn in fns:
653*da0073e9SAndroid Build Coastguard Worker                splits = fn(orig_tensor)
654*da0073e9SAndroid Build Coastguard Worker                for split in splits:
655*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(split.names, orig_tensor.names)
656*da0073e9SAndroid Build Coastguard Worker
657*da0073e9SAndroid Build Coastguard Worker    def test_any_all(self):
658*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
659*da0073e9SAndroid Build Coastguard Worker            x = torch.zeros(3, dtype=torch.bool, device=device, names=('C',))
660*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.any().names, [])
661*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.all().names, [])
662*da0073e9SAndroid Build Coastguard Worker
663*da0073e9SAndroid Build Coastguard Worker    def test_addcmul_addcdiv(self):
664*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
665*da0073e9SAndroid Build Coastguard Worker            names = ['N']
666*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(3, device=device, names=names)
667*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(3, device=device, names=names)
668*da0073e9SAndroid Build Coastguard Worker            # avoid division by 0
669*da0073e9SAndroid Build Coastguard Worker            c = torch.rand(3, device=device, names=names).clamp_min_(0.1)
670*da0073e9SAndroid Build Coastguard Worker            out = torch.randn(3, device=device, names=names)
671*da0073e9SAndroid Build Coastguard Worker
672*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.addcmul(a, b, c).names, names)
673*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.addcmul(a, b, c, out=out).names, names)
674*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.addcmul_(b, c).names, names)
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.addcdiv(a, b, c).names, names)
677*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.addcdiv(a, b, c, out=out).names, names)
678*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.addcdiv_(b, c).names, names)
679*da0073e9SAndroid Build Coastguard Worker
680*da0073e9SAndroid Build Coastguard Worker    def test_binary_ops(self):
681*da0073e9SAndroid Build Coastguard Worker        def test_basic(op):
682*da0073e9SAndroid Build Coastguard Worker            a = torch.empty(2, 3, names=('N', 'C'))
683*da0073e9SAndroid Build Coastguard Worker            b = torch.empty(3, 2, names=('C', 'N'))
684*da0073e9SAndroid Build Coastguard Worker            c = torch.empty(3, names=('C',))
685*da0073e9SAndroid Build Coastguard Worker            d = torch.empty(5, names=('W',))
686*da0073e9SAndroid Build Coastguard Worker
687*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(op(a, a).names, ('N', 'C'))
688*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(op(a, c).names, ('N', 'C'))
689*da0073e9SAndroid Build Coastguard Worker            # TODO: dynamo will throw a slightly different
690*da0073e9SAndroid Build Coastguard Worker            # error message because it's adding fake tensors
691*da0073e9SAndroid Build Coastguard Worker            # `must match the size of` portion is the dynamo error
692*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"):
693*da0073e9SAndroid Build Coastguard Worker                op(a, d)
694*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"):
695*da0073e9SAndroid Build Coastguard Worker                op(a, b)
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker        def test_wildcard(op):
698*da0073e9SAndroid Build Coastguard Worker            a = torch.empty(2, 3, names=('N', 'C'))
699*da0073e9SAndroid Build Coastguard Worker            c = torch.empty(2, 3, names=(None, 'C'))
700*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(op(a, c).names, ('N', 'C'))
701*da0073e9SAndroid Build Coastguard Worker
702*da0073e9SAndroid Build Coastguard Worker            b = torch.empty(2, 3)
703*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(op(a, b).names, ('N', 'C'))
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Worker            d = torch.empty(2, 3, names=('C', None))
706*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Misaligned"):
707*da0073e9SAndroid Build Coastguard Worker                op(d, c)
708*da0073e9SAndroid Build Coastguard Worker
709*da0073e9SAndroid Build Coastguard Worker        def test_mixed_unnamed_named(op, is_inplace):
710*da0073e9SAndroid Build Coastguard Worker            named2 = torch.randn(1, 1, names=('N', 'C'))
711*da0073e9SAndroid Build Coastguard Worker            unnamed1 = torch.randn(1)
712*da0073e9SAndroid Build Coastguard Worker            unnamed2 = torch.randn(1, 1)
713*da0073e9SAndroid Build Coastguard Worker            unnamed3 = torch.randn(1, 1, 1)
714*da0073e9SAndroid Build Coastguard Worker
715*da0073e9SAndroid Build Coastguard Worker            def compute_expected_names(tensor, other):
716*da0073e9SAndroid Build Coastguard Worker                assert tensor.has_names() ^ other.has_names()
717*da0073e9SAndroid Build Coastguard Worker                named = tensor if tensor.has_names() else other
718*da0073e9SAndroid Build Coastguard Worker                unnamed = other if tensor.has_names() else tensor
719*da0073e9SAndroid Build Coastguard Worker                unnamed_dim = unnamed.dim()
720*da0073e9SAndroid Build Coastguard Worker                if unnamed_dim > named.dim():
721*da0073e9SAndroid Build Coastguard Worker                    return [None] * (unnamed_dim - named.dim()) + list(named.names)
722*da0073e9SAndroid Build Coastguard Worker                else:
723*da0073e9SAndroid Build Coastguard Worker                    return named.names
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker            inputs = itertools.chain(
726*da0073e9SAndroid Build Coastguard Worker                itertools.product([named2], [unnamed1, unnamed2, unnamed3]),
727*da0073e9SAndroid Build Coastguard Worker                itertools.product([unnamed1, unnamed2, unnamed3], [named2]),
728*da0073e9SAndroid Build Coastguard Worker            )
729*da0073e9SAndroid Build Coastguard Worker            if is_inplace:
730*da0073e9SAndroid Build Coastguard Worker                # In-place ops have the constraint that they must not change shape.
731*da0073e9SAndroid Build Coastguard Worker                inputs = [(a, b) for (a, b) in inputs if a.dim() >= b.dim()]
732*da0073e9SAndroid Build Coastguard Worker
733*da0073e9SAndroid Build Coastguard Worker            for tensor, other in inputs:
734*da0073e9SAndroid Build Coastguard Worker                expected_names = compute_expected_names(tensor, other)
735*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(op(tensor, other).names, expected_names)
736*da0073e9SAndroid Build Coastguard Worker
737*da0073e9SAndroid Build Coastguard Worker        def method(name, *args, **kwargs):
738*da0073e9SAndroid Build Coastguard Worker            return [Function(name, lambda a, b: getattr(a, name)(b, *args, **kwargs))]
739*da0073e9SAndroid Build Coastguard Worker
740*da0073e9SAndroid Build Coastguard Worker        def function(name, *args, **kwargs):
741*da0073e9SAndroid Build Coastguard Worker            return [Function(name, lambda a, b: getattr(torch, name)(a, b, *args, **kwargs))]
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker        def out_function(name, *args, **kwargs):
744*da0073e9SAndroid Build Coastguard Worker            out_fn = getattr(torch, name)
745*da0073e9SAndroid Build Coastguard Worker
746*da0073e9SAndroid Build Coastguard Worker            def fn(a, b):
747*da0073e9SAndroid Build Coastguard Worker                result = torch.empty([0], dtype=a.dtype, device=a.device)
748*da0073e9SAndroid Build Coastguard Worker                out_fn(a, b, *args, out=result, **kwargs)
749*da0073e9SAndroid Build Coastguard Worker                return result
750*da0073e9SAndroid Build Coastguard Worker
751*da0073e9SAndroid Build Coastguard Worker            return [Function(name, fn)]
752*da0073e9SAndroid Build Coastguard Worker
753*da0073e9SAndroid Build Coastguard Worker        def fn_method_and_inplace(name, *args, **kwargs):
754*da0073e9SAndroid Build Coastguard Worker            return (
755*da0073e9SAndroid Build Coastguard Worker                method(name, *args, **kwargs) +
756*da0073e9SAndroid Build Coastguard Worker                method(name + '_', *args, **kwargs) +
757*da0073e9SAndroid Build Coastguard Worker                out_function(name, *args, **kwargs)
758*da0073e9SAndroid Build Coastguard Worker            )
759*da0073e9SAndroid Build Coastguard Worker
760*da0073e9SAndroid Build Coastguard Worker        tests = [
761*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('add'),
762*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('div'),
763*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('mul'),
764*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('sub'),
765*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('pow'),
766*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('atan2'),
767*da0073e9SAndroid Build Coastguard Worker            method('copy_'),
768*da0073e9SAndroid Build Coastguard Worker            function('floor_divide'),
769*da0073e9SAndroid Build Coastguard Worker            function('true_divide'),
770*da0073e9SAndroid Build Coastguard Worker        ]
771*da0073e9SAndroid Build Coastguard Worker        tests = flatten(tests)
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker        for name, op in tests:
774*da0073e9SAndroid Build Coastguard Worker            test_basic(op)
775*da0073e9SAndroid Build Coastguard Worker            test_wildcard(op)
776*da0073e9SAndroid Build Coastguard Worker            test_mixed_unnamed_named(op, is_inplace=name.endswith('_'))
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker    def test_logical_ops(self):
779*da0073e9SAndroid Build Coastguard Worker        # Implemented via TensorIterator, so just check that each version
780*da0073e9SAndroid Build Coastguard Worker        # (out-of-place, inplace, out=) propagates names.
781*da0073e9SAndroid Build Coastguard Worker        def zeros(*args, **kwargs):
782*da0073e9SAndroid Build Coastguard Worker            return torch.zeros(*args, dtype=torch.bool, **kwargs)
783*da0073e9SAndroid Build Coastguard Worker
784*da0073e9SAndroid Build Coastguard Worker        for op in ('logical_xor', 'logical_and', 'logical_or'):
785*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
786*da0073e9SAndroid Build Coastguard Worker                getattr(torch, op),
787*da0073e9SAndroid Build Coastguard Worker                (create('N:2,C:3', zeros), create('N:2,C:3', zeros)),
788*da0073e9SAndroid Build Coastguard Worker                expected_names=['N', 'C'])
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
791*da0073e9SAndroid Build Coastguard Worker                getattr(Tensor, op + '_'),
792*da0073e9SAndroid Build Coastguard Worker                (create('N:2,C:3', zeros), create('N:2,C:3', zeros)),
793*da0073e9SAndroid Build Coastguard Worker                expected_names=['N', 'C'])
794*da0073e9SAndroid Build Coastguard Worker
795*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
796*da0073e9SAndroid Build Coastguard Worker                lambda out, x, y: getattr(torch, op)(x, y, out=out),
797*da0073e9SAndroid Build Coastguard Worker                (create('0', zeros), create('N:2,C:3', zeros), create('N:2,C:3', zeros)),
798*da0073e9SAndroid Build Coastguard Worker                expected_names=['N', 'C'])
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Worker    def test_pow_special(self):
801*da0073e9SAndroid Build Coastguard Worker        # There are a few pow cases that don't go through TensorIterator.
802*da0073e9SAndroid Build Coastguard Worker        # Test them here.
803*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
804*da0073e9SAndroid Build Coastguard Worker            named = torch.randn(2, 3, names=('N', 'C'), device=device)
805*da0073e9SAndroid Build Coastguard Worker            unnamed = torch.randn([0], device=device)
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Worker            result = torch.pow(named, 0, out=unnamed.clone())
808*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.names, named.names)
809*da0073e9SAndroid Build Coastguard Worker
810*da0073e9SAndroid Build Coastguard Worker            result = torch.pow(named, 1, out=unnamed.clone())
811*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.names, named.names)
812*da0073e9SAndroid Build Coastguard Worker
813*da0073e9SAndroid Build Coastguard Worker            result = torch.pow(1, named, out=unnamed.clone())
814*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.names, named.names)
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Worker    def test_out_fn_semantics(self):
817*da0073e9SAndroid Build Coastguard Worker        out_fn = torch.abs
818*da0073e9SAndroid Build Coastguard Worker        unnamed_tensor = torch.randn(3, 2)
819*da0073e9SAndroid Build Coastguard Worker        none_named_tensor = torch.randn(3, 2, names=(None, None))
820*da0073e9SAndroid Build Coastguard Worker        named_tensor = torch.randn(3, 2, names=('N', 'C'))
821*da0073e9SAndroid Build Coastguard Worker        partially_named_tensor = torch.randn(3, 2, names=('N', None))
822*da0073e9SAndroid Build Coastguard Worker
823*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
824*da0073e9SAndroid Build Coastguard Worker            out_fn(partially_named_tensor, out=named_tensor)
825*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
826*da0073e9SAndroid Build Coastguard Worker            out_fn(named_tensor, out=partially_named_tensor)
827*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
828*da0073e9SAndroid Build Coastguard Worker            out_fn(none_named_tensor, out=named_tensor)
829*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
830*da0073e9SAndroid Build Coastguard Worker            out_fn(unnamed_tensor, out=named_tensor)
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Worker        output = torch.randn(3, 2)
833*da0073e9SAndroid Build Coastguard Worker        out_fn(unnamed_tensor, out=output)
834*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(output.has_names())
835*da0073e9SAndroid Build Coastguard Worker
836*da0073e9SAndroid Build Coastguard Worker        output = torch.randn(3, 2, names=(None, None))
837*da0073e9SAndroid Build Coastguard Worker        out_fn(named_tensor, out=output)
838*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, named_tensor.names)
839*da0073e9SAndroid Build Coastguard Worker
840*da0073e9SAndroid Build Coastguard Worker        output = torch.randn(3, 2)
841*da0073e9SAndroid Build Coastguard Worker        out_fn(named_tensor, out=output)
842*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, named_tensor.names)
843*da0073e9SAndroid Build Coastguard Worker
844*da0073e9SAndroid Build Coastguard Worker        output = torch.randn(3, 2, names=(None, None))
845*da0073e9SAndroid Build Coastguard Worker        out_fn(unnamed_tensor, out=output)
846*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(output.has_names())
847*da0073e9SAndroid Build Coastguard Worker
848*da0073e9SAndroid Build Coastguard Worker    def test_unary_propagate_names_fns(self):
849*da0073e9SAndroid Build Coastguard Worker        def _test(testcase, names=('N', 'D'), device='cpu'):
850*da0073e9SAndroid Build Coastguard Worker            sizes = [2] * len(names)
851*da0073e9SAndroid Build Coastguard Worker            tensor = torch.empty(sizes, names=names, device=device)
852*da0073e9SAndroid Build Coastguard Worker            try:
853*da0073e9SAndroid Build Coastguard Worker                out = testcase.lambd(tensor)
854*da0073e9SAndroid Build Coastguard Worker            except RuntimeError as err:
855*da0073e9SAndroid Build Coastguard Worker                # Get a better error message by catching the error and asserting.
856*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(f'{testcase.name}: {err}') from err
857*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.names, tensor.names,
858*da0073e9SAndroid Build Coastguard Worker                             msg=testcase.name)
859*da0073e9SAndroid Build Coastguard Worker
860*da0073e9SAndroid Build Coastguard Worker        def fn(name, *args, **kwargs):
861*da0073e9SAndroid Build Coastguard Worker            return [Function(name, lambda t: getattr(torch, name)(t, *args, **kwargs))]
862*da0073e9SAndroid Build Coastguard Worker
863*da0073e9SAndroid Build Coastguard Worker        def method(name, *args, **kwargs):
864*da0073e9SAndroid Build Coastguard Worker            return [Function(name, lambda t: getattr(t, name)(*args, **kwargs))]
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker        def out_function(name, *args, **kwargs):
867*da0073e9SAndroid Build Coastguard Worker            out_fn = getattr(torch, name)
868*da0073e9SAndroid Build Coastguard Worker
869*da0073e9SAndroid Build Coastguard Worker            def fn(tensor):
870*da0073e9SAndroid Build Coastguard Worker                result = torch.empty([0], dtype=tensor.dtype, device=tensor.device)
871*da0073e9SAndroid Build Coastguard Worker                out_fn(tensor, *args, out=result, **kwargs)
872*da0073e9SAndroid Build Coastguard Worker                return result
873*da0073e9SAndroid Build Coastguard Worker
874*da0073e9SAndroid Build Coastguard Worker            return [Function(name + '_out', fn)]
875*da0073e9SAndroid Build Coastguard Worker
876*da0073e9SAndroid Build Coastguard Worker        def fn_method_and_inplace(name, *args, **kwargs):
877*da0073e9SAndroid Build Coastguard Worker            return (
878*da0073e9SAndroid Build Coastguard Worker                method(name, *args, **kwargs) +
879*da0073e9SAndroid Build Coastguard Worker                method(name + '_', *args, **kwargs) +
880*da0073e9SAndroid Build Coastguard Worker                out_function(name, *args, **kwargs)
881*da0073e9SAndroid Build Coastguard Worker            )
882*da0073e9SAndroid Build Coastguard Worker
883*da0073e9SAndroid Build Coastguard Worker        # All of these operate on 2x2 tensors.
884*da0073e9SAndroid Build Coastguard Worker        tests = [
885*da0073e9SAndroid Build Coastguard Worker            # unary pointwise
886*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('abs'),
887*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('acos'),
888*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('asin'),
889*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('atan'),
890*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('ceil'),
891*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('clamp', -1, 1),
892*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('clamp_min', -2),
893*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('clamp_max', 2),
894*da0073e9SAndroid Build Coastguard Worker            method('cauchy_'),
895*da0073e9SAndroid Build Coastguard Worker            method('clone'),
896*da0073e9SAndroid Build Coastguard Worker            method('contiguous'),
897*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('cos'),
898*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('cosh'),
899*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('digamma'),
900*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('erf'),
901*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('erfc'),
902*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('erfinv'),
903*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('exp'),
904*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('expm1'),
905*da0073e9SAndroid Build Coastguard Worker            method('exponential_'),
906*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('floor'),
907*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('frac'),
908*da0073e9SAndroid Build Coastguard Worker            method('geometric_', p=0.5),
909*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('lgamma'),
910*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('log'),
911*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('log10'),
912*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('log1p'),
913*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('log2'),
914*da0073e9SAndroid Build Coastguard Worker            method('log_normal_'),
915*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('neg'),
916*da0073e9SAndroid Build Coastguard Worker            method('normal_'),
917*da0073e9SAndroid Build Coastguard Worker            [Function('polygamma', lambda t: torch.polygamma(1, t))],
918*da0073e9SAndroid Build Coastguard Worker            method('polygamma_', 1),
919*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('reciprocal'),
920*da0073e9SAndroid Build Coastguard Worker            method('random_', 0, 1),
921*da0073e9SAndroid Build Coastguard Worker            method('random_', 1),
922*da0073e9SAndroid Build Coastguard Worker            method('random_'),
923*da0073e9SAndroid Build Coastguard Worker            method('relu_'),
924*da0073e9SAndroid Build Coastguard Worker            method('requires_grad_'),
925*da0073e9SAndroid Build Coastguard Worker            method('relu'),
926*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('round'),
927*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('rsqrt'),
928*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('sigmoid'),
929*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('sign'),
930*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('sin'),
931*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('sinh'),
932*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('sqrt'),
933*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('tan'),
934*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('tanh'),
935*da0073e9SAndroid Build Coastguard Worker            fn('threshold', 0, 1),
936*da0073e9SAndroid Build Coastguard Worker            fn('threshold_', 0, 1),
937*da0073e9SAndroid Build Coastguard Worker            out_function('threshold', 0, 1),
938*da0073e9SAndroid Build Coastguard Worker            fn_method_and_inplace('trunc'),
939*da0073e9SAndroid Build Coastguard Worker            method('uniform_'),
940*da0073e9SAndroid Build Coastguard Worker            method('zero_'),
941*da0073e9SAndroid Build Coastguard Worker            method('fill_', 1),
942*da0073e9SAndroid Build Coastguard Worker            method('fill_', torch.tensor(3.14)),
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker            # conversions
945*da0073e9SAndroid Build Coastguard Worker            method('to', dtype=torch.long),
946*da0073e9SAndroid Build Coastguard Worker            method('to', device='cpu'),
947*da0073e9SAndroid Build Coastguard Worker            method('to', torch.empty([])),
948*da0073e9SAndroid Build Coastguard Worker            method('bool'),
949*da0073e9SAndroid Build Coastguard Worker            method('byte'),
950*da0073e9SAndroid Build Coastguard Worker            method('char'),
951*da0073e9SAndroid Build Coastguard Worker            method('cpu'),
952*da0073e9SAndroid Build Coastguard Worker            method('double'),
953*da0073e9SAndroid Build Coastguard Worker            method('float'),
954*da0073e9SAndroid Build Coastguard Worker            method('long'),
955*da0073e9SAndroid Build Coastguard Worker            method('half'),
956*da0073e9SAndroid Build Coastguard Worker            method('int'),
957*da0073e9SAndroid Build Coastguard Worker            method('short'),
958*da0073e9SAndroid Build Coastguard Worker            method('type', dtype=torch.long),
959*da0073e9SAndroid Build Coastguard Worker
960*da0073e9SAndroid Build Coastguard Worker            # cumsum and cumprod
961*da0073e9SAndroid Build Coastguard Worker            fn('cumsum', 0),
962*da0073e9SAndroid Build Coastguard Worker            fn('cumsum', 'D'),
963*da0073e9SAndroid Build Coastguard Worker            out_function('cumsum', 'D'),
964*da0073e9SAndroid Build Coastguard Worker            fn('cumprod', 0),
965*da0073e9SAndroid Build Coastguard Worker            fn('cumprod', 'D'),
966*da0073e9SAndroid Build Coastguard Worker            out_function('cumprod', 'D'),
967*da0073e9SAndroid Build Coastguard Worker
968*da0073e9SAndroid Build Coastguard Worker            # views
969*da0073e9SAndroid Build Coastguard Worker            method('narrow', 0, 0, 1),
970*da0073e9SAndroid Build Coastguard Worker
971*da0073e9SAndroid Build Coastguard Worker            # creation functions
972*da0073e9SAndroid Build Coastguard Worker            fn('empty_like'),
973*da0073e9SAndroid Build Coastguard Worker            fn('zeros_like'),
974*da0073e9SAndroid Build Coastguard Worker            fn('ones_like'),
975*da0073e9SAndroid Build Coastguard Worker            fn('full_like', 3.14),
976*da0073e9SAndroid Build Coastguard Worker            fn('rand_like'),
977*da0073e9SAndroid Build Coastguard Worker            fn('randn_like'),
978*da0073e9SAndroid Build Coastguard Worker
979*da0073e9SAndroid Build Coastguard Worker            # bernoulli variants
980*da0073e9SAndroid Build Coastguard Worker            method('bernoulli_', 0.5),
981*da0073e9SAndroid Build Coastguard Worker            method('bernoulli_', torch.tensor(0.5)),
982*da0073e9SAndroid Build Coastguard Worker
983*da0073e9SAndroid Build Coastguard Worker            method('softmax', dim=1),
984*da0073e9SAndroid Build Coastguard Worker            method('softmax', dim='D'),
985*da0073e9SAndroid Build Coastguard Worker            method('log_softmax', dim=1),
986*da0073e9SAndroid Build Coastguard Worker            method('log_softmax', dim='D'),
987*da0073e9SAndroid Build Coastguard Worker
988*da0073e9SAndroid Build Coastguard Worker            [Function('F.dropout(inplace)', lambda t: F.dropout(t, p=0.5, inplace=True))],
989*da0073e9SAndroid Build Coastguard Worker            [Function('F.dropout(outplace)', lambda t: F.dropout(t, p=0.5, inplace=False))],
990*da0073e9SAndroid Build Coastguard Worker        ]
991*da0073e9SAndroid Build Coastguard Worker        tests = flatten(tests)
992*da0073e9SAndroid Build Coastguard Worker
993*da0073e9SAndroid Build Coastguard Worker        for testcase, device in itertools.product(tests, get_all_device_types()):
994*da0073e9SAndroid Build Coastguard Worker            _test(testcase, device=device)
995*da0073e9SAndroid Build Coastguard Worker
996*da0073e9SAndroid Build Coastguard Worker    def test_cummax_cummin(self):
997*da0073e9SAndroid Build Coastguard Worker        def test_ops(op):
998*da0073e9SAndroid Build Coastguard Worker            for device in get_all_device_types():
999*da0073e9SAndroid Build Coastguard Worker                names = ('N', 'D')
1000*da0073e9SAndroid Build Coastguard Worker                tensor = torch.rand(2, 3, names=names)
1001*da0073e9SAndroid Build Coastguard Worker                result = op(tensor, 0)
1002*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result[0].names, names)
1003*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(result[1].names, names)
1004*da0073e9SAndroid Build Coastguard Worker        test_ops(torch.cummax)
1005*da0073e9SAndroid Build Coastguard Worker        test_ops(torch.cummin)
1006*da0073e9SAndroid Build Coastguard Worker
1007*da0073e9SAndroid Build Coastguard Worker    def test_logcumsumexp(self):
1008*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
1009*da0073e9SAndroid Build Coastguard Worker            names = ('N', 'D')
1010*da0073e9SAndroid Build Coastguard Worker            tensor = torch.rand(2, 3, names=names)
1011*da0073e9SAndroid Build Coastguard Worker            result = torch.logcumsumexp(tensor, 'D')
1012*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.names, names)
1013*da0073e9SAndroid Build Coastguard Worker
1014*da0073e9SAndroid Build Coastguard Worker    def test_bitwise_not(self):
1015*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
1016*da0073e9SAndroid Build Coastguard Worker            names = ('N', 'D')
1017*da0073e9SAndroid Build Coastguard Worker            tensor = torch.zeros(2, 3, names=names, dtype=torch.bool)
1018*da0073e9SAndroid Build Coastguard Worker            result = torch.empty(0, dtype=torch.bool)
1019*da0073e9SAndroid Build Coastguard Worker
1020*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor.bitwise_not().names, names)
1021*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.bitwise_not(tensor, out=result).names, names)
1022*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor.bitwise_not_().names, names)
1023*da0073e9SAndroid Build Coastguard Worker
1024*da0073e9SAndroid Build Coastguard Worker    def test_logical_not(self):
1025*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
1026*da0073e9SAndroid Build Coastguard Worker            names = ('N', 'D')
1027*da0073e9SAndroid Build Coastguard Worker            tensor = torch.zeros(2, 3, names=names, dtype=torch.bool)
1028*da0073e9SAndroid Build Coastguard Worker            result = torch.empty(0, dtype=torch.bool)
1029*da0073e9SAndroid Build Coastguard Worker
1030*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor.logical_not().names, names)
1031*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.logical_not(tensor, out=result).names, names)
1032*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor.logical_not_().names, names)
1033*da0073e9SAndroid Build Coastguard Worker
1034*da0073e9SAndroid Build Coastguard Worker    def test_bernoulli(self):
1035*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
1036*da0073e9SAndroid Build Coastguard Worker            names = ('N', 'D')
1037*da0073e9SAndroid Build Coastguard Worker            tensor = torch.rand(2, 3, names=names)
1038*da0073e9SAndroid Build Coastguard Worker            result = torch.empty(0)
1039*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor.bernoulli().names, names)
1040*da0073e9SAndroid Build Coastguard Worker
1041*da0073e9SAndroid Build Coastguard Worker            torch.bernoulli(tensor, out=result)
1042*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.names, names)
1043*da0073e9SAndroid Build Coastguard Worker
1044*da0073e9SAndroid Build Coastguard Worker    def test_flatten(self):
1045*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(2, 3, 5, 7, 11, names=('N', 'C', 'D', 'H', 'W'))
1046*da0073e9SAndroid Build Coastguard Worker
1047*da0073e9SAndroid Build Coastguard Worker        # basic
1048*da0073e9SAndroid Build Coastguard Worker        out = tensor.flatten('D', 'W', 'features')
1049*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.names, ['N', 'C', 'features'])
1050*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1))
1051*da0073e9SAndroid Build Coastguard Worker
1052*da0073e9SAndroid Build Coastguard Worker        # int overload
1053*da0073e9SAndroid Build Coastguard Worker        out = tensor.flatten(2, 4, 'features')
1054*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.names, ['N', 'C', 'features'])
1055*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1))
1056*da0073e9SAndroid Build Coastguard Worker
1057*da0073e9SAndroid Build Coastguard Worker        # list overload
1058*da0073e9SAndroid Build Coastguard Worker        out = tensor.flatten(['D', 'H', 'W'], 'features')
1059*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.names, ['N', 'C', 'features'])
1060*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1))
1061*da0073e9SAndroid Build Coastguard Worker
1062*da0073e9SAndroid Build Coastguard Worker        # Non-contiguous flatten: N and H are not "adjacent" in memory.
1063*da0073e9SAndroid Build Coastguard Worker        sentences = torch.randn(2, 3, 5, 7, names=('N', 'T', 'H', 'D'))
1064*da0073e9SAndroid Build Coastguard Worker        sentences = sentences.transpose('T', 'H')
1065*da0073e9SAndroid Build Coastguard Worker        out = sentences.flatten('N', 'H', 'N_H')
1066*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.names, ['N_H', 'T', 'D'])
1067*da0073e9SAndroid Build Coastguard Worker
1068*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Name 'L' not found in"):
1069*da0073e9SAndroid Build Coastguard Worker            tensor.flatten(['D', 'L'], 'features')
1070*da0073e9SAndroid Build Coastguard Worker
1071*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be consecutive in"):
1072*da0073e9SAndroid Build Coastguard Worker            tensor.flatten(['D', 'W'], 'features')
1073*da0073e9SAndroid Build Coastguard Worker
1074*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must be consecutive in"):
1075*da0073e9SAndroid Build Coastguard Worker            tensor.flatten(['H', 'D', 'W'], 'features')
1076*da0073e9SAndroid Build Coastguard Worker
1077*da0073e9SAndroid Build Coastguard Worker    def test_flatten_nodims(self):
1078*da0073e9SAndroid Build Coastguard Worker        tensor = torch.empty((2, 3))
1079*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "cannot be empty"):
1080*da0073e9SAndroid Build Coastguard Worker            tensor.flatten((), 'abcd')
1081*da0073e9SAndroid Build Coastguard Worker
1082*da0073e9SAndroid Build Coastguard Worker    def test_flatten_index_error(self):
1083*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(1, 2)
1084*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IndexError,
1085*da0073e9SAndroid Build Coastguard Worker                                    r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"):
1086*da0073e9SAndroid Build Coastguard Worker            tensor.flatten(0, 2)
1087*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(IndexError,
1088*da0073e9SAndroid Build Coastguard Worker                                    r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"):
1089*da0073e9SAndroid Build Coastguard Worker            tensor.flatten(0, 2, 'N')
1090*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
1091*da0073e9SAndroid Build Coastguard Worker                                    r"flatten\(\) has invalid args: start_dim cannot come after end_dim"):
1092*da0073e9SAndroid Build Coastguard Worker            tensor.flatten(1, 0)
1093*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
1094*da0073e9SAndroid Build Coastguard Worker                                    r"flatten\(\) has invalid args: start_dim cannot come after end_dim"):
1095*da0073e9SAndroid Build Coastguard Worker            tensor.flatten(1, 0, 'N')
1096*da0073e9SAndroid Build Coastguard Worker
1097*da0073e9SAndroid Build Coastguard Worker    def test_unflatten(self):
1098*da0073e9SAndroid Build Coastguard Worker        # test args: tensor, int, namedshape
1099*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(
1100*da0073e9SAndroid Build Coastguard Worker            torch.ones(4, names=('A',)).unflatten('A', (('A', 2), ('B', 2))),
1101*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2, names=('A', 'B'))))
1102*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(
1103*da0073e9SAndroid Build Coastguard Worker            torch.ones(4, names=('A',)).unflatten('A', [('A', 2), ('B', 2)]),
1104*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2, names=('A', 'B'))))
1105*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(
1106*da0073e9SAndroid Build Coastguard Worker            torch.ones(4, names=('A',)).unflatten('A', (['A', 2], ['B', 2])),
1107*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2, names=('A', 'B'))))
1108*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(
1109*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 10, names=('A', 'B')).unflatten('B', (['B1', -1],)),
1110*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 10, names=('A', 'B1'))))
1111*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(
1112*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 3 * 4 * 5 * 6, names=('A', 'B'))
1113*da0073e9SAndroid Build Coastguard Worker                 .unflatten('B', (['B1', 3], ['B2', 4], ['B3', -1], ['B4', 6])),
1114*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 3, 4, 5, 6, names=('A', 'B1', 'B2', 'B3', 'B4'))))
1115*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(
1116*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 0, names=('A', 'B'))
1117*da0073e9SAndroid Build Coastguard Worker                 .unflatten('B', (['B1', 3], ['B2', -1], ['B3', 4])),
1118*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 3, 0, 4, names=('A', 'B1', 'B2', 'B3'))))
1119*da0073e9SAndroid Build Coastguard Worker
1120*da0073e9SAndroid Build Coastguard Worker        # test args: namedtensor, str, namedshape
1121*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.equal(
1122*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))),
1123*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2, 2, names=('A', 'B1', 'B2'))))
1124*da0073e9SAndroid Build Coastguard Worker
1125*da0073e9SAndroid Build Coastguard Worker        # test invalid args: namedtensor, str, sizes
1126*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"):
1127*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1], names=('A',)).unflatten('A', (1, 1))
1128*da0073e9SAndroid Build Coastguard Worker
1129*da0073e9SAndroid Build Coastguard Worker        # test invalid args: namedtensor, int, sizes
1130*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, r"input is a named tensor but no names were given for unflattened sizes"):
1131*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1], names=("A",)).unflatten(0, (1, 1))
1132*da0073e9SAndroid Build Coastguard Worker
1133*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
1134*da0073e9SAndroid Build Coastguard Worker                                    r"Provided sizes \[3, -1\] don't multiply up to the "
1135*da0073e9SAndroid Build Coastguard Worker                                    r"size of dim 1 \('B': 4\) in Tensor\['A', 'B'\]"):
1136*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 3), ('B2', -1)))
1137*da0073e9SAndroid Build Coastguard Worker
1138*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
1139*da0073e9SAndroid Build Coastguard Worker                                    r"the unspecified dimension size -1 can be any value and is ambiguous"):
1140*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 0, names=('A', 'B')).unflatten('B', (('B1', 0), ('B2', -1)))
1141*da0073e9SAndroid Build Coastguard Worker
1142*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(7, 2 * 3 * 5, 11, names=('N', 'D', 'K'))
1143*da0073e9SAndroid Build Coastguard Worker
1144*da0073e9SAndroid Build Coastguard Worker        # accepts OrderedDict
1145*da0073e9SAndroid Build Coastguard Worker        out = tensor.unflatten('D', OrderedDict((('C', 2), ('H', 3), ('W', 5))))
1146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K'))
1147*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.shape, (7, 2, 3, 5, 11))
1148*da0073e9SAndroid Build Coastguard Worker
1149*da0073e9SAndroid Build Coastguard Worker        # Unflatten left-most
1150*da0073e9SAndroid Build Coastguard Worker        out = tensor.unflatten('N', (('N', 7), ('H', 1)))
1151*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.names, ('N', 'H', 'D', 'K'))
1152*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.shape, (7, 1, 2 * 3 * 5, 11))
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Worker        # Unflatten right-most
1155*da0073e9SAndroid Build Coastguard Worker        out = tensor.unflatten('K', (('K', 11), ('H', 1)))
1156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.names, ('N', 'D', 'K', 'H'))
1157*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.shape, (7, 2 * 3 * 5, 11, 1))
1158*da0073e9SAndroid Build Coastguard Worker
1159*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "don't multiply up to"):
1160*da0073e9SAndroid Build Coastguard Worker            tensor.unflatten('D', (('H', 3), ('W', 5)))
1161*da0073e9SAndroid Build Coastguard Worker
1162*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'sizes must be non-empty'):
1163*da0073e9SAndroid Build Coastguard Worker            tensor.unflatten('D', None)
1164*da0073e9SAndroid Build Coastguard Worker
1165*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'non-empty'):
1166*da0073e9SAndroid Build Coastguard Worker            tensor.unflatten('D', OrderedDict())
1167*da0073e9SAndroid Build Coastguard Worker
1168*da0073e9SAndroid Build Coastguard Worker    def test_unsupported_op_error_msg(self):
1169*da0073e9SAndroid Build Coastguard Worker        named = torch.randn(3, 3, names=('N', 'C'))
1170*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1171*da0073e9SAndroid Build Coastguard Worker                RuntimeError, r"pdist.+is not yet supported with named tensors"):
1172*da0073e9SAndroid Build Coastguard Worker            torch.pdist(named)
1173*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1174*da0073e9SAndroid Build Coastguard Worker                RuntimeError, r"as_strided_.+is not yet supported with named tensors"):
1175*da0073e9SAndroid Build Coastguard Worker            named.as_strided_((3, 3), (3, 1))
1176*da0073e9SAndroid Build Coastguard Worker
1177*da0073e9SAndroid Build Coastguard Worker    def test_reduction_fns(self):
1178*da0073e9SAndroid Build Coastguard Worker        def check_output(output, expected_names):
1179*da0073e9SAndroid Build Coastguard Worker            if isinstance(output, torch.Tensor):
1180*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(output.names, expected_names)
1181*da0073e9SAndroid Build Coastguard Worker                return
1182*da0073e9SAndroid Build Coastguard Worker            for out in output:
1183*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out.names, expected_names)
1184*da0073e9SAndroid Build Coastguard Worker
1185*da0073e9SAndroid Build Coastguard Worker        def sum_all_outputs(output):
1186*da0073e9SAndroid Build Coastguard Worker            if isinstance(output, torch.Tensor):
1187*da0073e9SAndroid Build Coastguard Worker                return output.sum()
1188*da0073e9SAndroid Build Coastguard Worker            result = 0
1189*da0073e9SAndroid Build Coastguard Worker            for out in output:
1190*da0073e9SAndroid Build Coastguard Worker                result = out + result
1191*da0073e9SAndroid Build Coastguard Worker            return result.sum()
1192*da0073e9SAndroid Build Coastguard Worker
1193*da0073e9SAndroid Build Coastguard Worker        def test_simple_reduce(op, device):
1194*da0073e9SAndroid Build Coastguard Worker            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
1195*da0073e9SAndroid Build Coastguard Worker            check_output(op(t, 1), ['N', 'L'])
1196*da0073e9SAndroid Build Coastguard Worker            check_output(op(t, -1), ['N', 'C'])
1197*da0073e9SAndroid Build Coastguard Worker            check_output(op(t, 'C'), ['N', 'L'])
1198*da0073e9SAndroid Build Coastguard Worker            ops_support_dim_none = [
1199*da0073e9SAndroid Build Coastguard Worker                'sum',
1200*da0073e9SAndroid Build Coastguard Worker                'mean',
1201*da0073e9SAndroid Build Coastguard Worker                'std',
1202*da0073e9SAndroid Build Coastguard Worker                'var',
1203*da0073e9SAndroid Build Coastguard Worker                'std_mean',
1204*da0073e9SAndroid Build Coastguard Worker                'var_mean',
1205*da0073e9SAndroid Build Coastguard Worker                'nanmean',
1206*da0073e9SAndroid Build Coastguard Worker                'nansum',
1207*da0073e9SAndroid Build Coastguard Worker            ]
1208*da0073e9SAndroid Build Coastguard Worker            if op.__name__ in ops_support_dim_none:
1209*da0073e9SAndroid Build Coastguard Worker                check_output(op(t, None), [])
1210*da0073e9SAndroid Build Coastguard Worker            else:
1211*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
1212*da0073e9SAndroid Build Coastguard Worker                    op(t, None)
1213*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'):
1214*da0073e9SAndroid Build Coastguard Worker                op(t, 'H')
1215*da0073e9SAndroid Build Coastguard Worker
1216*da0073e9SAndroid Build Coastguard Worker        def test_autograd_supports_dimname_overload(op, device):
1217*da0073e9SAndroid Build Coastguard Worker            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device, requires_grad=True)
1218*da0073e9SAndroid Build Coastguard Worker            sum_all_outputs(op(t, 'C')).backward()
1219*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(t.grad)
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker        def test_complete_reduce(op, device):
1222*da0073e9SAndroid Build Coastguard Worker            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
1223*da0073e9SAndroid Build Coastguard Worker            check_output(op(t), [])
1224*da0073e9SAndroid Build Coastguard Worker
1225*da0073e9SAndroid Build Coastguard Worker        def test_multidim_reduce(op, device):
1226*da0073e9SAndroid Build Coastguard Worker            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
1227*da0073e9SAndroid Build Coastguard Worker
1228*da0073e9SAndroid Build Coastguard Worker            check_output(op(t, [1, 2]), ['N'])
1229*da0073e9SAndroid Build Coastguard Worker            check_output(op(t, [0, -1]), ['C'])
1230*da0073e9SAndroid Build Coastguard Worker            check_output(op(t, ['C', 'L']), ['N'])
1231*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
1232*da0073e9SAndroid Build Coastguard Worker                op(t, [None, 'C'])
1233*da0073e9SAndroid Build Coastguard Worker
1234*da0073e9SAndroid Build Coastguard Worker        def test_out_variant(op, output_lambda, device):
1235*da0073e9SAndroid Build Coastguard Worker            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
1236*da0073e9SAndroid Build Coastguard Worker            if output_lambda:
1237*da0073e9SAndroid Build Coastguard Worker                out = output_lambda(t)
1238*da0073e9SAndroid Build Coastguard Worker            else:
1239*da0073e9SAndroid Build Coastguard Worker                out = torch.empty([0], device=device)
1240*da0073e9SAndroid Build Coastguard Worker            op(t, 'C', out=out)
1241*da0073e9SAndroid Build Coastguard Worker            check_output(out, ['N', 'L'])
1242*da0073e9SAndroid Build Coastguard Worker
1243*da0073e9SAndroid Build Coastguard Worker        def test_keepdim(op, device):
1244*da0073e9SAndroid Build Coastguard Worker            t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
1245*da0073e9SAndroid Build Coastguard Worker            check_output(op(t, 'C', keepdim=True), ['N', 'C', 'L'])
1246*da0073e9SAndroid Build Coastguard Worker
1247*da0073e9SAndroid Build Coastguard Worker        def values_and_indices(t):
1248*da0073e9SAndroid Build Coastguard Worker            return (torch.empty([0], device=t.device),
1249*da0073e9SAndroid Build Coastguard Worker                    torch.empty([0], device=t.device, dtype=torch.long))
1250*da0073e9SAndroid Build Coastguard Worker
1251*da0073e9SAndroid Build Coastguard Worker        def kthvalue_wrapper(tensor, *args, **kwargs):
1252*da0073e9SAndroid Build Coastguard Worker            # Return the 0-th value
1253*da0073e9SAndroid Build Coastguard Worker            return torch.kthvalue(tensor, 1, *args, **kwargs)
1254*da0073e9SAndroid Build Coastguard Worker
1255*da0073e9SAndroid Build Coastguard Worker        Case = namedtuple('Case', [
1256*da0073e9SAndroid Build Coastguard Worker            'op',
1257*da0073e9SAndroid Build Coastguard Worker            'supports_complete_reduce',
1258*da0073e9SAndroid Build Coastguard Worker            'supports_multidim_reduce',
1259*da0073e9SAndroid Build Coastguard Worker            'supports_out_variant',
1260*da0073e9SAndroid Build Coastguard Worker            'supports_keepdim',
1261*da0073e9SAndroid Build Coastguard Worker            'output_lambda',
1262*da0073e9SAndroid Build Coastguard Worker        ])
1263*da0073e9SAndroid Build Coastguard Worker
1264*da0073e9SAndroid Build Coastguard Worker        tests = [
1265*da0073e9SAndroid Build Coastguard Worker            Case(torch.sum, True, True, True, True, None),
1266*da0073e9SAndroid Build Coastguard Worker            Case(torch.prod, True, False, True, True, None),
1267*da0073e9SAndroid Build Coastguard Worker            Case(torch.mean, True, True, True, True, None),
1268*da0073e9SAndroid Build Coastguard Worker            Case(torch.var, True, True, True, True, None),
1269*da0073e9SAndroid Build Coastguard Worker            Case(torch.std, True, True, True, True, None),
1270*da0073e9SAndroid Build Coastguard Worker            Case(torch.std_mean, True, True, False, True, None),
1271*da0073e9SAndroid Build Coastguard Worker            Case(torch.var_mean, True, True, False, True, None),
1272*da0073e9SAndroid Build Coastguard Worker            Case(torch.min, True, False, True, True, values_and_indices),
1273*da0073e9SAndroid Build Coastguard Worker            Case(torch.max, True, False, True, True, values_and_indices),
1274*da0073e9SAndroid Build Coastguard Worker            Case(torch.unbind, False, False, False, False, None),
1275*da0073e9SAndroid Build Coastguard Worker            Case(torch.logsumexp, False, True, True, True, None),
1276*da0073e9SAndroid Build Coastguard Worker            Case(torch.mode, False, False, True, True, values_and_indices),
1277*da0073e9SAndroid Build Coastguard Worker            Case(kthvalue_wrapper, False, False, True, True, values_and_indices),
1278*da0073e9SAndroid Build Coastguard Worker            Case(torch.median, True, False, True, True, values_and_indices),
1279*da0073e9SAndroid Build Coastguard Worker            Case(torch.nanmedian, True, False, True, True, values_and_indices),
1280*da0073e9SAndroid Build Coastguard Worker        ]
1281*da0073e9SAndroid Build Coastguard Worker
1282*da0073e9SAndroid Build Coastguard Worker        for testcase, device in itertools.product(tests, get_all_device_types()):
1283*da0073e9SAndroid Build Coastguard Worker            op = testcase.op
1284*da0073e9SAndroid Build Coastguard Worker            test_simple_reduce(op, device)
1285*da0073e9SAndroid Build Coastguard Worker            test_autograd_supports_dimname_overload(op, device)
1286*da0073e9SAndroid Build Coastguard Worker
1287*da0073e9SAndroid Build Coastguard Worker            if testcase.supports_keepdim:
1288*da0073e9SAndroid Build Coastguard Worker                test_keepdim(op, device)
1289*da0073e9SAndroid Build Coastguard Worker            if testcase.supports_out_variant:
1290*da0073e9SAndroid Build Coastguard Worker                test_out_variant(op, testcase.output_lambda, device)
1291*da0073e9SAndroid Build Coastguard Worker            if testcase.supports_complete_reduce:
1292*da0073e9SAndroid Build Coastguard Worker                test_complete_reduce(op, device)
1293*da0073e9SAndroid Build Coastguard Worker            if testcase.supports_multidim_reduce:
1294*da0073e9SAndroid Build Coastguard Worker                test_multidim_reduce(op, device)
1295*da0073e9SAndroid Build Coastguard Worker
1296*da0073e9SAndroid Build Coastguard Worker    def test_masked_select(self):
1297*da0073e9SAndroid Build Coastguard Worker        # simple
1298*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1299*da0073e9SAndroid Build Coastguard Worker            torch.masked_select,
1300*da0073e9SAndroid Build Coastguard Worker            (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')),
1301*da0073e9SAndroid Build Coastguard Worker            expected_names=[None])
1302*da0073e9SAndroid Build Coastguard Worker
1303*da0073e9SAndroid Build Coastguard Worker        # left broadcast
1304*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1305*da0073e9SAndroid Build Coastguard Worker            torch.masked_select,
1306*da0073e9SAndroid Build Coastguard Worker            (create('C:3'), (create('2,3') > 0).rename('N', 'C')),
1307*da0073e9SAndroid Build Coastguard Worker            expected_names=[None])
1308*da0073e9SAndroid Build Coastguard Worker
1309*da0073e9SAndroid Build Coastguard Worker        # right broadcast
1310*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1311*da0073e9SAndroid Build Coastguard Worker            torch.masked_select,
1312*da0073e9SAndroid Build Coastguard Worker            (create('N:2,C:3'), (create('3') > 0).rename('C')),
1313*da0073e9SAndroid Build Coastguard Worker            expected_names=[None])
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Worker        # error
1316*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1317*da0073e9SAndroid Build Coastguard Worker            torch.masked_select,
1318*da0073e9SAndroid Build Coastguard Worker            (create('N:2,C:3'), (create('3') > 0).rename('D')),
1319*da0073e9SAndroid Build Coastguard Worker            maybe_raises_regex='do not match')
1320*da0073e9SAndroid Build Coastguard Worker
1321*da0073e9SAndroid Build Coastguard Worker        # out=
1322*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1323*da0073e9SAndroid Build Coastguard Worker            out_fn(torch.masked_select),
1324*da0073e9SAndroid Build Coastguard Worker            (create('0'), create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')),
1325*da0073e9SAndroid Build Coastguard Worker            expected_names=[None])
1326*da0073e9SAndroid Build Coastguard Worker
1327*da0073e9SAndroid Build Coastguard Worker    def test_cat(self):
1328*da0073e9SAndroid Build Coastguard Worker        # simple
1329*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1330*da0073e9SAndroid Build Coastguard Worker            torch.cat,
1331*da0073e9SAndroid Build Coastguard Worker            [[create('N:2,C:3'), create('N:2,C:3')]],
1332*da0073e9SAndroid Build Coastguard Worker            expected_names=['N', 'C'])
1333*da0073e9SAndroid Build Coastguard Worker
1334*da0073e9SAndroid Build Coastguard Worker        # error: zero dim
1335*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1336*da0073e9SAndroid Build Coastguard Worker            torch.cat,
1337*da0073e9SAndroid Build Coastguard Worker            [[create(''), create('')]],
1338*da0073e9SAndroid Build Coastguard Worker            maybe_raises_regex='zero-dim')
1339*da0073e9SAndroid Build Coastguard Worker
1340*da0073e9SAndroid Build Coastguard Worker        # error: names don't match
1341*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1342*da0073e9SAndroid Build Coastguard Worker            torch.cat,
1343*da0073e9SAndroid Build Coastguard Worker            [[create('N:2,C:3'), create('C:3,N:2')]],
1344*da0073e9SAndroid Build Coastguard Worker            maybe_raises_regex='do not match')
1345*da0073e9SAndroid Build Coastguard Worker
1346*da0073e9SAndroid Build Coastguard Worker        # error: different number of dims
1347*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1348*da0073e9SAndroid Build Coastguard Worker            torch.cat,
1349*da0073e9SAndroid Build Coastguard Worker            [[create('N:2,C:3'), create('C:3')]],
1350*da0073e9SAndroid Build Coastguard Worker            maybe_raises_regex='must have same number of dimensions')
1351*da0073e9SAndroid Build Coastguard Worker
1352*da0073e9SAndroid Build Coastguard Worker        # out=
1353*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1354*da0073e9SAndroid Build Coastguard Worker            out_fn(torch.cat),
1355*da0073e9SAndroid Build Coastguard Worker            [create('0'), [create('N:2,C:3'), create('N:2,C:3')]],
1356*da0073e9SAndroid Build Coastguard Worker            expected_names=['N', 'C'])
1357*da0073e9SAndroid Build Coastguard Worker
1358*da0073e9SAndroid Build Coastguard Worker    def test_masked_fill(self):
1359*da0073e9SAndroid Build Coastguard Worker        # simple
1360*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1361*da0073e9SAndroid Build Coastguard Worker            Tensor.masked_fill,
1362*da0073e9SAndroid Build Coastguard Worker            (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
1363*da0073e9SAndroid Build Coastguard Worker            expected_names=['N', 'C'])
1364*da0073e9SAndroid Build Coastguard Worker
1365*da0073e9SAndroid Build Coastguard Worker        # left broadcast
1366*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1367*da0073e9SAndroid Build Coastguard Worker            Tensor.masked_fill,
1368*da0073e9SAndroid Build Coastguard Worker            (create('C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
1369*da0073e9SAndroid Build Coastguard Worker            maybe_raises_regex="must be less than or equal to")
1370*da0073e9SAndroid Build Coastguard Worker
1371*da0073e9SAndroid Build Coastguard Worker        # right broadcast
1372*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1373*da0073e9SAndroid Build Coastguard Worker            Tensor.masked_fill,
1374*da0073e9SAndroid Build Coastguard Worker            (create('N:2,C:3'), (create('3') > 0).rename('C'), 3.14),
1375*da0073e9SAndroid Build Coastguard Worker            expected_names=['N', 'C'])
1376*da0073e9SAndroid Build Coastguard Worker
1377*da0073e9SAndroid Build Coastguard Worker        # error
1378*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1379*da0073e9SAndroid Build Coastguard Worker            Tensor.masked_fill,
1380*da0073e9SAndroid Build Coastguard Worker            (create('N:2,C:3'), (create('3') > 0).rename('D'), 3.14),
1381*da0073e9SAndroid Build Coastguard Worker            maybe_raises_regex='do not match')
1382*da0073e9SAndroid Build Coastguard Worker
1383*da0073e9SAndroid Build Coastguard Worker        # inplace
1384*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1385*da0073e9SAndroid Build Coastguard Worker            Tensor.masked_fill_,
1386*da0073e9SAndroid Build Coastguard Worker            (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
1387*da0073e9SAndroid Build Coastguard Worker            expected_names=['N', 'C'])
1388*da0073e9SAndroid Build Coastguard Worker
1389*da0073e9SAndroid Build Coastguard Worker        # inplace, computed names don't match output tensor names
1390*da0073e9SAndroid Build Coastguard Worker        self._test_name_inference(
1391*da0073e9SAndroid Build Coastguard Worker            Tensor.masked_fill_,
1392*da0073e9SAndroid Build Coastguard Worker            (create('N:2,None:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
1393*da0073e9SAndroid Build Coastguard Worker            maybe_raises_regex="not the same as the computed output names")
1394*da0073e9SAndroid Build Coastguard Worker
1395*da0073e9SAndroid Build Coastguard Worker
1396*da0073e9SAndroid Build Coastguard Worker    def test_using_seen_interned_string_doesnt_bump_refcount(self):
1397*da0073e9SAndroid Build Coastguard Worker        def see_name():
1398*da0073e9SAndroid Build Coastguard Worker            seen_name = 'N'
1399*da0073e9SAndroid Build Coastguard Worker            pass_name_to_python_arg_parser(seen_name)
1400*da0073e9SAndroid Build Coastguard Worker
1401*da0073e9SAndroid Build Coastguard Worker        see_name()
1402*da0073e9SAndroid Build Coastguard Worker        seen_name = 'N'
1403*da0073e9SAndroid Build Coastguard Worker        old_refcnt = sys.getrefcount(seen_name)
1404*da0073e9SAndroid Build Coastguard Worker
1405*da0073e9SAndroid Build Coastguard Worker        pass_name_to_python_arg_parser(seen_name)
1406*da0073e9SAndroid Build Coastguard Worker
1407*da0073e9SAndroid Build Coastguard Worker        new_refcnt = sys.getrefcount(seen_name)
1408*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(new_refcnt, old_refcnt)
1409*da0073e9SAndroid Build Coastguard Worker
1410*da0073e9SAndroid Build Coastguard Worker    # This test is failing on Python 3.12: https://github.com/pytorch/pytorch/issues/119464
1411*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(sys.version_info >= (3, 12), "Failing on python 3.12+")
1412*da0073e9SAndroid Build Coastguard Worker    def test_using_unseen_interned_string_bumps_refcount_permanently(self):
1413*da0073e9SAndroid Build Coastguard Worker        # Please don't use this as a name in a different test.
1414*da0073e9SAndroid Build Coastguard Worker        unseen_name = 'abcdefghi'
1415*da0073e9SAndroid Build Coastguard Worker        old_refcnt = sys.getrefcount(unseen_name)
1416*da0073e9SAndroid Build Coastguard Worker
1417*da0073e9SAndroid Build Coastguard Worker        pass_name_to_python_arg_parser(unseen_name)
1418*da0073e9SAndroid Build Coastguard Worker
1419*da0073e9SAndroid Build Coastguard Worker        new_refcnt = sys.getrefcount(unseen_name)
1420*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(new_refcnt, old_refcnt + 1)
1421*da0073e9SAndroid Build Coastguard Worker
1422*da0073e9SAndroid Build Coastguard Worker    # This test is failing on Python 3.12: https://github.com/pytorch/pytorch/issues/119464
1423*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(sys.version_info >= (3, 12), "Failing on python 3.12+")
1424*da0073e9SAndroid Build Coastguard Worker    def test_using_unseen_uninterned_string_refcounts(self):
1425*da0073e9SAndroid Build Coastguard Worker        # Please don't use this as a name in a different test.
1426*da0073e9SAndroid Build Coastguard Worker        # non-compile-time constants are not interned
1427*da0073e9SAndroid Build Coastguard Worker        unseen_name = ''.join(['abc', 'def', 'ghi', 'jkl'])
1428*da0073e9SAndroid Build Coastguard Worker        interned_unseen_name = 'abcdefghijkl'
1429*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(unseen_name is interned_unseen_name)
1430*da0073e9SAndroid Build Coastguard Worker
1431*da0073e9SAndroid Build Coastguard Worker        old_uninterned_refcnt = sys.getrefcount(unseen_name)
1432*da0073e9SAndroid Build Coastguard Worker        old_interned_refcnt = sys.getrefcount(interned_unseen_name)
1433*da0073e9SAndroid Build Coastguard Worker
1434*da0073e9SAndroid Build Coastguard Worker        pass_name_to_python_arg_parser(unseen_name)
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker        new_uninterned_refcnt = sys.getrefcount(unseen_name)
1437*da0073e9SAndroid Build Coastguard Worker        new_interned_refcnt = sys.getrefcount(interned_unseen_name)
1438*da0073e9SAndroid Build Coastguard Worker
1439*da0073e9SAndroid Build Coastguard Worker        # Internally, PyTorch should not hold a reference to the uninterned string
1440*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(new_uninterned_refcnt, old_uninterned_refcnt)
1441*da0073e9SAndroid Build Coastguard Worker
1442*da0073e9SAndroid Build Coastguard Worker        # Instead, we should hold a new reference to the interned version.
1443*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(new_interned_refcnt, old_interned_refcnt + 1)
1444*da0073e9SAndroid Build Coastguard Worker
1445*da0073e9SAndroid Build Coastguard Worker    def _test_select(self, device):
1446*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device)
1447*da0073e9SAndroid Build Coastguard Worker        y = x.select(1, 1)
1448*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.names, ('N', 'H', 'W'))
1449*da0073e9SAndroid Build Coastguard Worker
1450*da0073e9SAndroid Build Coastguard Worker        y = x.select('C', 1)
1451*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.names, ('N', 'H', 'W'))
1452*da0073e9SAndroid Build Coastguard Worker
1453*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1454*da0073e9SAndroid Build Coastguard Worker                RuntimeError, 'Please look up dimensions by name'):
1455*da0073e9SAndroid Build Coastguard Worker            y = x.select(None, 1)
1456*da0073e9SAndroid Build Coastguard Worker
1457*da0073e9SAndroid Build Coastguard Worker    def test_select(self):
1458*da0073e9SAndroid Build Coastguard Worker        self._test_select('cpu')
1459*da0073e9SAndroid Build Coastguard Worker
1460*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, 'no CUDA')
1461*da0073e9SAndroid Build Coastguard Worker    def test_select_cuda(self):
1462*da0073e9SAndroid Build Coastguard Worker        self._test_select('cuda')
1463*da0073e9SAndroid Build Coastguard Worker
1464*da0073e9SAndroid Build Coastguard Worker    def _test_as_strided(self, device):
1465*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device)
1466*da0073e9SAndroid Build Coastguard Worker        y = x.as_strided([2 * 3 * 4 * 5], [1])
1467*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.names, (None,))
1468*da0073e9SAndroid Build Coastguard Worker
1469*da0073e9SAndroid Build Coastguard Worker    def test_as_strided(self):
1470*da0073e9SAndroid Build Coastguard Worker        self._test_as_strided('cpu')
1471*da0073e9SAndroid Build Coastguard Worker
1472*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, 'no CUDA')
1473*da0073e9SAndroid Build Coastguard Worker    def test_as_strided_cuda(self):
1474*da0073e9SAndroid Build Coastguard Worker        self._test_as_strided('cuda')
1475*da0073e9SAndroid Build Coastguard Worker
1476*da0073e9SAndroid Build Coastguard Worker    def test_no_jit_tracer_support(self):
1477*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1478*da0073e9SAndroid Build Coastguard Worker            return torch.full(x.shape, 2., names=('N',))
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'):
1481*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3)
1482*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(foo, example_inputs=x)
1483*da0073e9SAndroid Build Coastguard Worker
1484*da0073e9SAndroid Build Coastguard Worker        def bar(x):
1485*da0073e9SAndroid Build Coastguard Worker            return x.select('N', 1)
1486*da0073e9SAndroid Build Coastguard Worker
1487*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'):
1488*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3)
1489*da0073e9SAndroid Build Coastguard Worker            torch.jit.trace(bar, example_inputs=x)
1490*da0073e9SAndroid Build Coastguard Worker
1491*da0073e9SAndroid Build Coastguard Worker    def test_no_jit_script_support(self):
1492*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1493*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1494*da0073e9SAndroid Build Coastguard Worker            return x + 1
1495*da0073e9SAndroid Build Coastguard Worker
1496*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'NYI'):
1497*da0073e9SAndroid Build Coastguard Worker            foo(torch.randn(2, 3, names=('N', 'C')))
1498*da0073e9SAndroid Build Coastguard Worker
1499*da0073e9SAndroid Build Coastguard Worker        @torch.jit.ignore
1500*da0073e9SAndroid Build Coastguard Worker        def add_names(x):
1501*da0073e9SAndroid Build Coastguard Worker            x.names = ('N', 'C')
1502*da0073e9SAndroid Build Coastguard Worker
1503*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
1504*da0073e9SAndroid Build Coastguard Worker        def return_named_tensor(input):
1505*da0073e9SAndroid Build Coastguard Worker            add_names(input)
1506*da0073e9SAndroid Build Coastguard Worker            return input
1507*da0073e9SAndroid Build Coastguard Worker
1508*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "NYI"):
1509*da0073e9SAndroid Build Coastguard Worker            return_named_tensor(torch.randn(1, 1))
1510*da0073e9SAndroid Build Coastguard Worker
1511*da0073e9SAndroid Build Coastguard Worker    def test_align_to(self):
1512*da0073e9SAndroid Build Coastguard Worker        # trivial
1513*da0073e9SAndroid Build Coastguard Worker        tensor = create('N:3')
1514*da0073e9SAndroid Build Coastguard Worker        output = tensor.align_to('N')
1515*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['N'])
1516*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, [3])
1517*da0073e9SAndroid Build Coastguard Worker
1518*da0073e9SAndroid Build Coastguard Worker        # unsqueeze behavior
1519*da0073e9SAndroid Build Coastguard Worker        tensor = create('N:3')
1520*da0073e9SAndroid Build Coastguard Worker        output = tensor.align_to('N', 'D')
1521*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['N', 'D'])
1522*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, [3, 1])
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker        # transpose behavior
1525*da0073e9SAndroid Build Coastguard Worker        tensor = create('N:3,C:2')
1526*da0073e9SAndroid Build Coastguard Worker        output = tensor.align_to('C', 'N')
1527*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['C', 'N'])
1528*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, [2, 3])
1529*da0073e9SAndroid Build Coastguard Worker
1530*da0073e9SAndroid Build Coastguard Worker        # unsqueeze / transpose
1531*da0073e9SAndroid Build Coastguard Worker        tensor = create('C:2,N:3,H:5')
1532*da0073e9SAndroid Build Coastguard Worker        output = tensor.align_to('N', 'H', 'W', 'C')
1533*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['N', 'H', 'W', 'C'])
1534*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, [3, 5, 1, 2])
1535*da0073e9SAndroid Build Coastguard Worker
1536*da0073e9SAndroid Build Coastguard Worker        # All input dimensions must be named
1537*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "All input dims must be named. Found unnamed dim at index 0"):
1538*da0073e9SAndroid Build Coastguard Worker            create('None:2,C:3').align_to('N', 'C')
1539*da0073e9SAndroid Build Coastguard Worker
1540*da0073e9SAndroid Build Coastguard Worker        # not enough names
1541*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'N'"):
1542*da0073e9SAndroid Build Coastguard Worker            create('N:2,C:3').align_to('C')
1543*da0073e9SAndroid Build Coastguard Worker
1544*da0073e9SAndroid Build Coastguard Worker        # names not found
1545*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'C'"):
1546*da0073e9SAndroid Build Coastguard Worker            create('N:2,C:3').align_to('D', 'N')
1547*da0073e9SAndroid Build Coastguard Worker
1548*da0073e9SAndroid Build Coastguard Worker    def test_align_to_ellipsis(self):
1549*da0073e9SAndroid Build Coastguard Worker        tensor = create('N:7,H:3,W:5,C:2')
1550*da0073e9SAndroid Build Coastguard Worker
1551*da0073e9SAndroid Build Coastguard Worker        # ... = ['N', 'H', 'W', 'C']
1552*da0073e9SAndroid Build Coastguard Worker        output = tensor.align_to('...')
1553*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['N', 'H', 'W', 'C'])
1554*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, [7, 3, 5, 2])
1555*da0073e9SAndroid Build Coastguard Worker
1556*da0073e9SAndroid Build Coastguard Worker        # ... = ['H', 'C']
1557*da0073e9SAndroid Build Coastguard Worker        output = tensor.align_to('...', 'W', 'N')
1558*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['H', 'C', 'W', 'N'])
1559*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, [3, 2, 5, 7])
1560*da0073e9SAndroid Build Coastguard Worker
1561*da0073e9SAndroid Build Coastguard Worker        # ... = ['N', 'W']
1562*da0073e9SAndroid Build Coastguard Worker        output = tensor.align_to('H', 'C', '...')
1563*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['H', 'C', 'N', 'W'])
1564*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, [3, 2, 7, 5])
1565*da0073e9SAndroid Build Coastguard Worker
1566*da0073e9SAndroid Build Coastguard Worker        # ... = ['H', 'C']
1567*da0073e9SAndroid Build Coastguard Worker        output = tensor.align_to('W', '...', 'N')
1568*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['W', 'H', 'C', 'N'])
1569*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, [5, 3, 2, 7])
1570*da0073e9SAndroid Build Coastguard Worker
1571*da0073e9SAndroid Build Coastguard Worker        # ... = []
1572*da0073e9SAndroid Build Coastguard Worker        output = tensor.align_to('N', '...', 'C', 'D', 'H', 'W')
1573*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['N', 'C', 'D', 'H', 'W'])
1574*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, [7, 2, 1, 3, 5])
1575*da0073e9SAndroid Build Coastguard Worker
1576*da0073e9SAndroid Build Coastguard Worker        # Input tensor partially named
1577*da0073e9SAndroid Build Coastguard Worker        partially_named = create('None:2,None:3,None:5,C:7')
1578*da0073e9SAndroid Build Coastguard Worker        output = partially_named.align_to('C', '...')
1579*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['C', None, None, None])
1580*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, [7, 2, 3, 5])
1581*da0073e9SAndroid Build Coastguard Worker
1582*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "order of dimensions cannot contain a None"):
1583*da0073e9SAndroid Build Coastguard Worker            partially_named.align_to('C', None, '...')
1584*da0073e9SAndroid Build Coastguard Worker
1585*da0073e9SAndroid Build Coastguard Worker        # Input order partially named
1586*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "cannot contain a None name"):
1587*da0073e9SAndroid Build Coastguard Worker            tensor.align_to('...', 'N', None)
1588*da0073e9SAndroid Build Coastguard Worker
1589*da0073e9SAndroid Build Coastguard Worker        # Input order duplicate names
1590*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "duplicate names"):
1591*da0073e9SAndroid Build Coastguard Worker            tensor.align_to('...', 'N', 'N')
1592*da0073e9SAndroid Build Coastguard Worker
1593*da0073e9SAndroid Build Coastguard Worker    def test_align_as(self):
1594*da0073e9SAndroid Build Coastguard Worker        # align_as calls align_to internally. align_to has pretty substantial tests,
1595*da0073e9SAndroid Build Coastguard Worker        # so just test some basic things here.
1596*da0073e9SAndroid Build Coastguard Worker        tensor = create('C:2,N:3,H:5')
1597*da0073e9SAndroid Build Coastguard Worker        other = create('N:1,H:1,W:1,C:1')
1598*da0073e9SAndroid Build Coastguard Worker        output = tensor.align_as(other)
1599*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.names, ['N', 'H', 'W', 'C'])
1600*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.shape, [3, 5, 1, 2])
1601*da0073e9SAndroid Build Coastguard Worker
1602*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Not implemented yet")
1603*da0073e9SAndroid Build Coastguard Worker    def test_align_tensors_two_inputs(self):
1604*da0073e9SAndroid Build Coastguard Worker        def _test(tensor_namedshape, align_names, expected_sizes, expected_error):
1605*da0073e9SAndroid Build Coastguard Worker            tensor_names, tensor_sizes = tensor_namedshape
1606*da0073e9SAndroid Build Coastguard Worker            tensor = torch.empty(*tensor_sizes, names=tensor_names)
1607*da0073e9SAndroid Build Coastguard Worker            other = torch.empty([1] * len(align_names), names=align_names)
1608*da0073e9SAndroid Build Coastguard Worker            if expected_error is not None:
1609*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, expected_error):
1610*da0073e9SAndroid Build Coastguard Worker                    torch.align_tensors(tensor, other)
1611*da0073e9SAndroid Build Coastguard Worker                return
1612*da0073e9SAndroid Build Coastguard Worker
1613*da0073e9SAndroid Build Coastguard Worker            output, _ = torch.align_tensors(tensor, other)
1614*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output.shape, expected_sizes)
1615*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output.names, align_names)
1616*da0073e9SAndroid Build Coastguard Worker
1617*da0073e9SAndroid Build Coastguard Worker        Case = namedtuple('Case', [
1618*da0073e9SAndroid Build Coastguard Worker            'tensor_namedshape',
1619*da0073e9SAndroid Build Coastguard Worker            'align_names',
1620*da0073e9SAndroid Build Coastguard Worker            'expected_sizes',
1621*da0073e9SAndroid Build Coastguard Worker            'expected_error',
1622*da0073e9SAndroid Build Coastguard Worker        ])
1623*da0073e9SAndroid Build Coastguard Worker
1624*da0073e9SAndroid Build Coastguard Worker        tests = [
1625*da0073e9SAndroid Build Coastguard Worker            # basic tests
1626*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=(['C'], [2]),
1627*da0073e9SAndroid Build Coastguard Worker                 align_names=['C'],
1628*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[2],
1629*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1630*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=(['C'], [2]),
1631*da0073e9SAndroid Build Coastguard Worker                 align_names=['D'],
1632*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=None,
1633*da0073e9SAndroid Build Coastguard Worker                 expected_error='not a subsequence'),
1634*da0073e9SAndroid Build Coastguard Worker
1635*da0073e9SAndroid Build Coastguard Worker            # single-dim alignment test
1636*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=(['C'], [2]),
1637*da0073e9SAndroid Build Coastguard Worker                 align_names=['N', 'C'],
1638*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[1, 2],
1639*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1640*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[['N'], [2]],
1641*da0073e9SAndroid Build Coastguard Worker                 align_names=['N', 'C'],
1642*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[2, 1],
1643*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1644*da0073e9SAndroid Build Coastguard Worker
1645*da0073e9SAndroid Build Coastguard Worker            # multiple dim alignment test
1646*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[['N', 'C'], [2, 3]],
1647*da0073e9SAndroid Build Coastguard Worker                 align_names=['N', 'H', 'C', 'W'],
1648*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[2, 1, 3, 1],
1649*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1650*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[['N', 'C'], [2, 3]],
1651*da0073e9SAndroid Build Coastguard Worker                 align_names=['C', 'H', 'N', 'W'],
1652*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=None,
1653*da0073e9SAndroid Build Coastguard Worker                 expected_error='not a subsequence'),
1654*da0073e9SAndroid Build Coastguard Worker
1655*da0073e9SAndroid Build Coastguard Worker            # scalar tensor tests
1656*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[None, [[]]],
1657*da0073e9SAndroid Build Coastguard Worker                 align_names=['N', 'C'],
1658*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[1, 1],
1659*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1660*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[[], [[]]],
1661*da0073e9SAndroid Build Coastguard Worker                 align_names=[None, None],
1662*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[1, 1],
1663*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1664*da0073e9SAndroid Build Coastguard Worker
1665*da0073e9SAndroid Build Coastguard Worker            # unnamed tensor tests
1666*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[None, [2, 3]],
1667*da0073e9SAndroid Build Coastguard Worker                 align_names=[None, None],
1668*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[2, 3],
1669*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1670*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[None, [2, 3]],
1671*da0073e9SAndroid Build Coastguard Worker                 align_names=[None, None, None],
1672*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[1, 2, 3],
1673*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1674*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[None, [2]],
1675*da0073e9SAndroid Build Coastguard Worker                 align_names=['N'],
1676*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=None,
1677*da0073e9SAndroid Build Coastguard Worker                 expected_error='not a subsequence'),
1678*da0073e9SAndroid Build Coastguard Worker
1679*da0073e9SAndroid Build Coastguard Worker            # unnamed dim alignment tests
1680*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[[None], [2]],
1681*da0073e9SAndroid Build Coastguard Worker                 align_names=['N', None],
1682*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[1, 2],
1683*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1684*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[[None], [2]],
1685*da0073e9SAndroid Build Coastguard Worker                 align_names=['N', None, None, None],
1686*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[1, 1, 1, 2],
1687*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1688*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[['N'], [2]],
1689*da0073e9SAndroid Build Coastguard Worker                 align_names=['N', None, None, None],
1690*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[2, 1, 1, 1],
1691*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1692*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[[None, 'N', None], [2, 3, 5]],
1693*da0073e9SAndroid Build Coastguard Worker                 align_names=[None, None, 'N', None],
1694*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=[1, 2, 3, 5],
1695*da0073e9SAndroid Build Coastguard Worker                 expected_error=None),
1696*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[[None], [2]],
1697*da0073e9SAndroid Build Coastguard Worker                 align_names=[None, 'N'],
1698*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=None,
1699*da0073e9SAndroid Build Coastguard Worker                 expected_error='absolute position from the right'),
1700*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[None, [2]],
1701*da0073e9SAndroid Build Coastguard Worker                 align_names=[None, 'N'],
1702*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=None,
1703*da0073e9SAndroid Build Coastguard Worker                 expected_error='absolute position from the right'),
1704*da0073e9SAndroid Build Coastguard Worker            Case(tensor_namedshape=[[None, 'N'], [2, 3]],
1705*da0073e9SAndroid Build Coastguard Worker                 align_names=[None, 'C', 'N'],
1706*da0073e9SAndroid Build Coastguard Worker                 expected_sizes=None,
1707*da0073e9SAndroid Build Coastguard Worker                 expected_error='absolute position from the right'),
1708*da0073e9SAndroid Build Coastguard Worker        ]
1709*da0073e9SAndroid Build Coastguard Worker
1710*da0073e9SAndroid Build Coastguard Worker        for test in tests:
1711*da0073e9SAndroid Build Coastguard Worker            _test(*test)
1712*da0073e9SAndroid Build Coastguard Worker
1713*da0073e9SAndroid Build Coastguard Worker    @unittest.skip("Not implemented yet")
1714*da0073e9SAndroid Build Coastguard Worker    def test_align_tensors(self):
1715*da0073e9SAndroid Build Coastguard Worker        def reference_fn(*tensors):
1716*da0073e9SAndroid Build Coastguard Worker            longest_names = tensors[0].names
1717*da0073e9SAndroid Build Coastguard Worker            for tensor in tensors:
1718*da0073e9SAndroid Build Coastguard Worker                if len(tensor.names) > len(longest_names):
1719*da0073e9SAndroid Build Coastguard Worker                    longest_names = tensor.names
1720*da0073e9SAndroid Build Coastguard Worker            return [tensor.align_to(*longest_names) for tensor in tensors]
1721*da0073e9SAndroid Build Coastguard Worker
1722*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(1, 1, names=('N', 'H'))
1723*da0073e9SAndroid Build Coastguard Worker        y = torch.empty(2, 3, 5, names=('N', 'C', 'H'))
1724*da0073e9SAndroid Build Coastguard Worker        z = torch.empty(2, names=('N',))
1725*da0073e9SAndroid Build Coastguard Worker        output = torch.align_tensors(x, y, z)
1726*da0073e9SAndroid Build Coastguard Worker        expected_tensors = reference_fn(x, y, z)
1727*da0073e9SAndroid Build Coastguard Worker        for tensor, expected in zip(output, expected_tensors):
1728*da0073e9SAndroid Build Coastguard Worker            self.assertTensorDataAndNamesEqual(tensor, expected)
1729*da0073e9SAndroid Build Coastguard Worker
1730*da0073e9SAndroid Build Coastguard Worker    def test_mm(self):
1731*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
1732*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1733*da0073e9SAndroid Build Coastguard Worker                torch.mm, device=device,
1734*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,C:2'), create('W:2,H:5')),
1735*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', 'H'))
1736*da0073e9SAndroid Build Coastguard Worker
1737*da0073e9SAndroid Build Coastguard Worker            # left arg is unnamed
1738*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1739*da0073e9SAndroid Build Coastguard Worker                torch.mm, device=device,
1740*da0073e9SAndroid Build Coastguard Worker                args=(create('3,2'), create('W:2,H:5')),
1741*da0073e9SAndroid Build Coastguard Worker                expected_names=(None, 'H'))
1742*da0073e9SAndroid Build Coastguard Worker
1743*da0073e9SAndroid Build Coastguard Worker            # right arg is unnamed
1744*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1745*da0073e9SAndroid Build Coastguard Worker                torch.mm, device=device,
1746*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,C:2'), create('2,5')),
1747*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', None))
1748*da0073e9SAndroid Build Coastguard Worker
1749*da0073e9SAndroid Build Coastguard Worker            # out=
1750*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1751*da0073e9SAndroid Build Coastguard Worker                out_fn(torch.mm), device=device,
1752*da0073e9SAndroid Build Coastguard Worker                args=(create('0'), create('N:3,C:2'), create('W:2,H:5')),
1753*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', 'H'))
1754*da0073e9SAndroid Build Coastguard Worker
1755*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1756*da0073e9SAndroid Build Coastguard Worker                torch.mm, device=device,
1757*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,C:2'), create('W:2,N:5')),
1758*da0073e9SAndroid Build Coastguard Worker                maybe_raises_regex='with duplicate names')
1759*da0073e9SAndroid Build Coastguard Worker
1760*da0073e9SAndroid Build Coastguard Worker    def test_expand(self):
1761*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
1762*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1763*da0073e9SAndroid Build Coastguard Worker                Tensor.expand, device=device,
1764*da0073e9SAndroid Build Coastguard Worker                args=(create('D:1'), [3]), expected_names=('D',))
1765*da0073e9SAndroid Build Coastguard Worker
1766*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1767*da0073e9SAndroid Build Coastguard Worker                Tensor.expand, device=device,
1768*da0073e9SAndroid Build Coastguard Worker                args=(create('H:3,W:2'), [10, 3, 3, 2]),
1769*da0073e9SAndroid Build Coastguard Worker                expected_names=(None, None, 'H', 'W'))
1770*da0073e9SAndroid Build Coastguard Worker
1771*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1772*da0073e9SAndroid Build Coastguard Worker                Tensor.expand, device=device,
1773*da0073e9SAndroid Build Coastguard Worker                args=(create('3, 2'), [10, 3, 3, 2]),
1774*da0073e9SAndroid Build Coastguard Worker                expected_names=(None, None, None, None))
1775*da0073e9SAndroid Build Coastguard Worker
1776*da0073e9SAndroid Build Coastguard Worker    def test_addmm(self):
1777*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
1778*da0073e9SAndroid Build Coastguard Worker            # full names
1779*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1780*da0073e9SAndroid Build Coastguard Worker                torch.addmm, device=device,
1781*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')),
1782*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', 'H'))
1783*da0073e9SAndroid Build Coastguard Worker
1784*da0073e9SAndroid Build Coastguard Worker            # no name on bias
1785*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1786*da0073e9SAndroid Build Coastguard Worker                torch.addmm, device=device,
1787*da0073e9SAndroid Build Coastguard Worker                args=(create('3,5'), create('N:3,C:2'), create('W:2,H:5')),
1788*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', 'H'))
1789*da0073e9SAndroid Build Coastguard Worker
1790*da0073e9SAndroid Build Coastguard Worker            # partially named bias
1791*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1792*da0073e9SAndroid Build Coastguard Worker                torch.addmm, device=device,
1793*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')),
1794*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', 'H'))
1795*da0073e9SAndroid Build Coastguard Worker
1796*da0073e9SAndroid Build Coastguard Worker            # out=
1797*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1798*da0073e9SAndroid Build Coastguard Worker                out_fn(torch.addmm), device=device,
1799*da0073e9SAndroid Build Coastguard Worker                args=(create('0'), create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')),
1800*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', 'H'))
1801*da0073e9SAndroid Build Coastguard Worker
1802*da0073e9SAndroid Build Coastguard Worker            # inplace
1803*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1804*da0073e9SAndroid Build Coastguard Worker                torch.Tensor.addmm_, device=device,
1805*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')),
1806*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', 'H'))
1807*da0073e9SAndroid Build Coastguard Worker
1808*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1809*da0073e9SAndroid Build Coastguard Worker                torch.addmm, device=device,
1810*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,N:5')),
1811*da0073e9SAndroid Build Coastguard Worker                maybe_raises_regex='with duplicate names')
1812*da0073e9SAndroid Build Coastguard Worker
1813*da0073e9SAndroid Build Coastguard Worker    def test_bmm(self):
1814*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
1815*da0073e9SAndroid Build Coastguard Worker            # full names
1816*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1817*da0073e9SAndroid Build Coastguard Worker                torch.bmm, device=device,
1818*da0073e9SAndroid Build Coastguard Worker                args=(create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
1819*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', 'A', 'B'))
1820*da0073e9SAndroid Build Coastguard Worker
1821*da0073e9SAndroid Build Coastguard Worker            # no name on left tensor
1822*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1823*da0073e9SAndroid Build Coastguard Worker                torch.bmm, device=device,
1824*da0073e9SAndroid Build Coastguard Worker                args=(create('7,3,2'), create('N:7,A:2,B:5')),
1825*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', None, 'B'))
1826*da0073e9SAndroid Build Coastguard Worker
1827*da0073e9SAndroid Build Coastguard Worker            # no name on right tensor
1828*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1829*da0073e9SAndroid Build Coastguard Worker                torch.bmm, device=device,
1830*da0073e9SAndroid Build Coastguard Worker                args=(create('N:7,A:3,B:2'), create('7,2,5')),
1831*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', 'A', None))
1832*da0073e9SAndroid Build Coastguard Worker
1833*da0073e9SAndroid Build Coastguard Worker            # out=
1834*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1835*da0073e9SAndroid Build Coastguard Worker                out_fn(torch.bmm), device=device,
1836*da0073e9SAndroid Build Coastguard Worker                args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
1837*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', 'A', 'B'))
1838*da0073e9SAndroid Build Coastguard Worker
1839*da0073e9SAndroid Build Coastguard Worker            # duplicate names after mm
1840*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1841*da0073e9SAndroid Build Coastguard Worker                torch.bmm, device=device,
1842*da0073e9SAndroid Build Coastguard Worker                args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')),
1843*da0073e9SAndroid Build Coastguard Worker                maybe_raises_regex='with duplicate names')
1844*da0073e9SAndroid Build Coastguard Worker
1845*da0073e9SAndroid Build Coastguard Worker            # matching error (batch dimensions must be alignable)
1846*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1847*da0073e9SAndroid Build Coastguard Worker                torch.bmm, device=device,
1848*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,A:3,B:3'), create('M:3,A:3,B:3')),
1849*da0073e9SAndroid Build Coastguard Worker                maybe_raises_regex='do not match')
1850*da0073e9SAndroid Build Coastguard Worker
1851*da0073e9SAndroid Build Coastguard Worker            # misalignment (batch dimension is getting contracted)
1852*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1853*da0073e9SAndroid Build Coastguard Worker                torch.bmm, device=device,
1854*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,A:3,B:3'), create('None:3,N:3,B:3')),
1855*da0073e9SAndroid Build Coastguard Worker                maybe_raises_regex='misaligned')
1856*da0073e9SAndroid Build Coastguard Worker
1857*da0073e9SAndroid Build Coastguard Worker    def test_matmul(self):
1858*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
1859*da0073e9SAndroid Build Coastguard Worker            # input tensors are less than 1D
1860*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1861*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1862*da0073e9SAndroid Build Coastguard Worker                args=(create(''), create('A:2')),
1863*da0073e9SAndroid Build Coastguard Worker                maybe_raises_regex='at least 1D')
1864*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1865*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1866*da0073e9SAndroid Build Coastguard Worker                args=(create('A:2'), create('')),
1867*da0073e9SAndroid Build Coastguard Worker                maybe_raises_regex='at least 1D')
1868*da0073e9SAndroid Build Coastguard Worker
1869*da0073e9SAndroid Build Coastguard Worker            # 1D @ 1D
1870*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1871*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1872*da0073e9SAndroid Build Coastguard Worker                args=(create('A:2'), create('B:2')),
1873*da0073e9SAndroid Build Coastguard Worker                expected_names=[])
1874*da0073e9SAndroid Build Coastguard Worker
1875*da0073e9SAndroid Build Coastguard Worker            # ND @ 1D
1876*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1877*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1878*da0073e9SAndroid Build Coastguard Worker                args=(create('A:3,C:2'), create('B:2')),
1879*da0073e9SAndroid Build Coastguard Worker                expected_names=['A'])
1880*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1881*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1882*da0073e9SAndroid Build Coastguard Worker                args=(create('A:5,C:3,D:2'), create('B:2')),
1883*da0073e9SAndroid Build Coastguard Worker                expected_names=['A', 'C'])
1884*da0073e9SAndroid Build Coastguard Worker
1885*da0073e9SAndroid Build Coastguard Worker            # 1D @ ND
1886*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1887*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1888*da0073e9SAndroid Build Coastguard Worker                args=(create('C:2'), create('A:2,B:3')),
1889*da0073e9SAndroid Build Coastguard Worker                expected_names=['B'])
1890*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1891*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1892*da0073e9SAndroid Build Coastguard Worker                args=(create('C:2'), create('A:3,B:2,D:5')),
1893*da0073e9SAndroid Build Coastguard Worker                expected_names=['A', 'D'])
1894*da0073e9SAndroid Build Coastguard Worker
1895*da0073e9SAndroid Build Coastguard Worker            # 2D @ 2D
1896*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1897*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1898*da0073e9SAndroid Build Coastguard Worker                args=(create('A:3,B:2'), create('A:2,B:3')),
1899*da0073e9SAndroid Build Coastguard Worker                expected_names=['A', 'B'])
1900*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1901*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1902*da0073e9SAndroid Build Coastguard Worker                args=(create('A:3,B:2'), create('B:2,A:5')),
1903*da0073e9SAndroid Build Coastguard Worker                maybe_raises_regex='with duplicate names')
1904*da0073e9SAndroid Build Coastguard Worker
1905*da0073e9SAndroid Build Coastguard Worker            # ND @ ND where N >= 2
1906*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1907*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1908*da0073e9SAndroid Build Coastguard Worker                args=(create('C:5,A:3,B:2'), create('A:2,B:3')),
1909*da0073e9SAndroid Build Coastguard Worker                expected_names=['C', 'A', 'B'])
1910*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1911*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1912*da0073e9SAndroid Build Coastguard Worker                args=(create('C:5,A:3,B:2'), create('None:1,A:2,B:3')),
1913*da0073e9SAndroid Build Coastguard Worker                expected_names=['C', 'A', 'B'])
1914*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1915*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1916*da0073e9SAndroid Build Coastguard Worker                args=(create('C:5,A:3,B:2'), create('None:2,None:1,A:2,B:3')),
1917*da0073e9SAndroid Build Coastguard Worker                expected_names=[None, 'C', 'A', 'B'])
1918*da0073e9SAndroid Build Coastguard Worker
1919*da0073e9SAndroid Build Coastguard Worker            # out=
1920*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1921*da0073e9SAndroid Build Coastguard Worker                out_fn(torch.matmul), device=device,
1922*da0073e9SAndroid Build Coastguard Worker                args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
1923*da0073e9SAndroid Build Coastguard Worker                expected_names=('N', 'A', 'B'))
1924*da0073e9SAndroid Build Coastguard Worker
1925*da0073e9SAndroid Build Coastguard Worker            # duplicate names after mm
1926*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1927*da0073e9SAndroid Build Coastguard Worker                torch.bmm, device=device,
1928*da0073e9SAndroid Build Coastguard Worker                args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')),
1929*da0073e9SAndroid Build Coastguard Worker                maybe_raises_regex='with duplicate names')
1930*da0073e9SAndroid Build Coastguard Worker
1931*da0073e9SAndroid Build Coastguard Worker            # misalignment (batch dimension is getting contracted)
1932*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1933*da0073e9SAndroid Build Coastguard Worker                torch.matmul, device=device,
1934*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,A:3,B:3'), create('A:3,N:3,B:3')),
1935*da0073e9SAndroid Build Coastguard Worker                maybe_raises_regex='do not match')
1936*da0073e9SAndroid Build Coastguard Worker
1937*da0073e9SAndroid Build Coastguard Worker    def test_mv(self):
1938*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
1939*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1940*da0073e9SAndroid Build Coastguard Worker                torch.mv, device=device,
1941*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,C:2'), create('W:2')),
1942*da0073e9SAndroid Build Coastguard Worker                expected_names=('N',))
1943*da0073e9SAndroid Build Coastguard Worker
1944*da0073e9SAndroid Build Coastguard Worker            # left arg is unnamed
1945*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1946*da0073e9SAndroid Build Coastguard Worker                torch.mv, device=device,
1947*da0073e9SAndroid Build Coastguard Worker                args=(create('3,2'), create('W:2')),
1948*da0073e9SAndroid Build Coastguard Worker                expected_names=(None,))
1949*da0073e9SAndroid Build Coastguard Worker
1950*da0073e9SAndroid Build Coastguard Worker            # right arg is unnamed
1951*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1952*da0073e9SAndroid Build Coastguard Worker                torch.mv, device=device,
1953*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3,C:2'), create('2')),
1954*da0073e9SAndroid Build Coastguard Worker                expected_names=('N',))
1955*da0073e9SAndroid Build Coastguard Worker
1956*da0073e9SAndroid Build Coastguard Worker            # out=
1957*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1958*da0073e9SAndroid Build Coastguard Worker                out_fn(torch.mv), device=device,
1959*da0073e9SAndroid Build Coastguard Worker                args=(create('0'), create('N:3,C:2'), create('W:2')),
1960*da0073e9SAndroid Build Coastguard Worker                expected_names=('N',))
1961*da0073e9SAndroid Build Coastguard Worker
1962*da0073e9SAndroid Build Coastguard Worker    def test_addmv(self):
1963*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
1964*da0073e9SAndroid Build Coastguard Worker            # full names
1965*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1966*da0073e9SAndroid Build Coastguard Worker                torch.addmv, device=device,
1967*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3'), create('N:3,C:2'), create('H:2')),
1968*da0073e9SAndroid Build Coastguard Worker                expected_names=['N'])
1969*da0073e9SAndroid Build Coastguard Worker
1970*da0073e9SAndroid Build Coastguard Worker            # no name on bias
1971*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1972*da0073e9SAndroid Build Coastguard Worker                torch.addmv, device=device,
1973*da0073e9SAndroid Build Coastguard Worker                args=(create('3'), create('N:3,C:2'), create('H:2')),
1974*da0073e9SAndroid Build Coastguard Worker                expected_names=('N',))
1975*da0073e9SAndroid Build Coastguard Worker
1976*da0073e9SAndroid Build Coastguard Worker            # out=
1977*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1978*da0073e9SAndroid Build Coastguard Worker                out_fn(torch.addmv), device=device,
1979*da0073e9SAndroid Build Coastguard Worker                args=(create('0'), create('N:3'), create('N:3,C:2'), create('H:2')),
1980*da0073e9SAndroid Build Coastguard Worker                expected_names=('N',))
1981*da0073e9SAndroid Build Coastguard Worker
1982*da0073e9SAndroid Build Coastguard Worker            # inplace
1983*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
1984*da0073e9SAndroid Build Coastguard Worker                torch.Tensor.addmv_, device=device,
1985*da0073e9SAndroid Build Coastguard Worker                args=(create('N:3'), create('N:3,C:2'), create('H:2')),
1986*da0073e9SAndroid Build Coastguard Worker                expected_names=('N',))
1987*da0073e9SAndroid Build Coastguard Worker
1988*da0073e9SAndroid Build Coastguard Worker    def test_autograd_ignores_names(self):
1989*da0073e9SAndroid Build Coastguard Worker        # sigmoid forward is supported by named tensors, but sigmoid_backward
1990*da0073e9SAndroid Build Coastguard Worker        # is not (see native_functions.yaml). Test that autograd ignores names
1991*da0073e9SAndroid Build Coastguard Worker        # and that the sigmoid_backward succeeds.
1992*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3, names=('N', 'C'), requires_grad=True)
1993*da0073e9SAndroid Build Coastguard Worker        x.sigmoid().sum().backward()
1994*da0073e9SAndroid Build Coastguard Worker
1995*da0073e9SAndroid Build Coastguard Worker    def test_tensor_grad_is_unnamed(self):
1996*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3, names=(None, None), requires_grad=True)
1997*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(3, 3, names=('N', 'C'), requires_grad=True)
1998*da0073e9SAndroid Build Coastguard Worker        (x * y).sum().backward()
1999*da0073e9SAndroid Build Coastguard Worker
2000*da0073e9SAndroid Build Coastguard Worker        # Check that names weren't propagated
2001*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad.names, [None, None])
2002*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad.names, [None, None])
2003*da0073e9SAndroid Build Coastguard Worker
2004*da0073e9SAndroid Build Coastguard Worker    def test_autograd_warns_named_grad(self):
2005*da0073e9SAndroid Build Coastguard Worker        base = torch.randn(3, 3, names=('N', 'C'))
2006*da0073e9SAndroid Build Coastguard Worker        named_grad = base.clone()
2007*da0073e9SAndroid Build Coastguard Worker        base.requires_grad_()
2008*da0073e9SAndroid Build Coastguard Worker
2009*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as warns:
2010*da0073e9SAndroid Build Coastguard Worker            # Cause all warnings to always be triggered.
2011*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
2012*da0073e9SAndroid Build Coastguard Worker            base.clone().backward(named_grad)
2013*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(warns), 1)
2014*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
2015*da0073e9SAndroid Build Coastguard Worker                str(warns[0].message).startswith('Autograd was passed a named grad tensor'))
2016*da0073e9SAndroid Build Coastguard Worker
2017*da0073e9SAndroid Build Coastguard Worker    def test_nyi_dimname_overload_msg(self):
2018*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3)
2019*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "squeeze: You passed a dimname"):
2020*da0073e9SAndroid Build Coastguard Worker            x.squeeze_("N")
2021*da0073e9SAndroid Build Coastguard Worker
2022*da0073e9SAndroid Build Coastguard Worker    def test_dot(self):
2023*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
2024*da0073e9SAndroid Build Coastguard Worker            # torch.dot ignores the names of both tensors
2025*da0073e9SAndroid Build Coastguard Worker            self._test_name_inference(
2026*da0073e9SAndroid Build Coastguard Worker                torch.dot, device=device,
2027*da0073e9SAndroid Build Coastguard Worker                args=(create('C:2'), create('W:2')),
2028*da0073e9SAndroid Build Coastguard Worker                expected_names=[])
2029*da0073e9SAndroid Build Coastguard Worker
2030*da0073e9SAndroid Build Coastguard Worker    def test_comparison_ops(self):
2031*da0073e9SAndroid Build Coastguard Worker        for device in get_all_device_types():
2032*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(3, 3, names=('N', 'C'), device=device)
2033*da0073e9SAndroid Build Coastguard Worker            b = torch.randn(3, 3, names=('N', 'C'), device=device)
2034*da0073e9SAndroid Build Coastguard Worker            scalar = torch.randn([], device=device)
2035*da0073e9SAndroid Build Coastguard Worker
2036*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a == b).names, ['N', 'C'])
2037*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a != b).names, ['N', 'C'])
2038*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a > b).names, ['N', 'C'])
2039*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a < b).names, ['N', 'C'])
2040*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a >= b).names, ['N', 'C'])
2041*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a <= b).names, ['N', 'C'])
2042*da0073e9SAndroid Build Coastguard Worker
2043*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a == 1).names, ['N', 'C'])
2044*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a != 1).names, ['N', 'C'])
2045*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a > 1).names, ['N', 'C'])
2046*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a < 1).names, ['N', 'C'])
2047*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a >= 1).names, ['N', 'C'])
2048*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a <= 1).names, ['N', 'C'])
2049*da0073e9SAndroid Build Coastguard Worker
2050*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a == scalar).names, ['N', 'C'])
2051*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a != scalar).names, ['N', 'C'])
2052*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a > scalar).names, ['N', 'C'])
2053*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a < scalar).names, ['N', 'C'])
2054*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a >= scalar).names, ['N', 'C'])
2055*da0073e9SAndroid Build Coastguard Worker            self.assertEqual((a <= scalar).names, ['N', 'C'])
2056*da0073e9SAndroid Build Coastguard Worker
2057*da0073e9SAndroid Build Coastguard Worker            res = torch.empty(3, 3, dtype=torch.bool, device=device)
2058*da0073e9SAndroid Build Coastguard Worker            torch.eq(a, b, out=res)
2059*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.names, ['N', 'C'])
2060*da0073e9SAndroid Build Coastguard Worker            torch.ne(a, b, out=res)
2061*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.names, ['N', 'C'])
2062*da0073e9SAndroid Build Coastguard Worker            torch.lt(a, b, out=res)
2063*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.names, ['N', 'C'])
2064*da0073e9SAndroid Build Coastguard Worker            torch.gt(a, b, out=res)
2065*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.names, ['N', 'C'])
2066*da0073e9SAndroid Build Coastguard Worker            torch.le(a, b, out=res)
2067*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.names, ['N', 'C'])
2068*da0073e9SAndroid Build Coastguard Worker            torch.ge(a, b, out=res)
2069*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.names, ['N', 'C'])
2070*da0073e9SAndroid Build Coastguard Worker
2071*da0073e9SAndroid Build Coastguard Worker            res = torch.isnan(a)
2072*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.names, ['N', 'C'])
2073*da0073e9SAndroid Build Coastguard Worker
2074*da0073e9SAndroid Build Coastguard Worker            res = torch.isinf(a)
2075*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.names, ['N', 'C'])
2076*da0073e9SAndroid Build Coastguard Worker
2077*da0073e9SAndroid Build Coastguard Worker    def test_support_device_named_grad(self):
2078*da0073e9SAndroid Build Coastguard Worker        named_tensor = torch.randn(3, 3, device='meta')
2079*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'NYI: named tensors only support CPU, CUDA'):
2080*da0073e9SAndroid Build Coastguard Worker            named_tensor.rename_('N', 'C')
2081*da0073e9SAndroid Build Coastguard Worker            named_tensor.names = ['N', 'C']
2082*da0073e9SAndroid Build Coastguard Worker            named_tensor = torch.randn(3, 3, device='meta', names=['N', 'C'])
2083*da0073e9SAndroid Build Coastguard Worker
2084*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Only when executed as a script (not on import): hand off to the shared
    # run_tests entry point imported from common_utils.
    run_tests()
2087