1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: named tensor"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport unittest 4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY 5*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import skipIfTorchDynamo 6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA 7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import get_all_device_types 8*da0073e9SAndroid Build Coastguard Workerfrom collections import namedtuple, OrderedDict 9*da0073e9SAndroid Build Coastguard Workerimport itertools 10*da0073e9SAndroid Build Coastguard Workerimport functools 11*da0073e9SAndroid Build Coastguard Workerimport torch 12*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor 13*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F 14*da0073e9SAndroid Build Coastguard Workerfrom multiprocessing.reduction import ForkingPickler 15*da0073e9SAndroid Build Coastguard Workerimport pickle 16*da0073e9SAndroid Build Coastguard Workerimport io 17*da0073e9SAndroid Build Coastguard Workerimport sys 18*da0073e9SAndroid Build Coastguard Workerimport warnings 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Workerdef pass_name_to_python_arg_parser(name): 22*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2, names=(name,)) 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Workerdef flatten(lst): 26*da0073e9SAndroid Build Coastguard Worker return [item for sublist in lst for item in sublist] 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard WorkerFunction = namedtuple('TestCase', ['name', 'lambd']) 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Workerdef parse_compressed_namedshape(string): 33*da0073e9SAndroid Build Coastguard Worker # This is a metalanguage for describing a shape of a tensor compactly. 34*da0073e9SAndroid Build Coastguard Worker # 'N:3,C:2' -> size = [3, 2], names: ['N', 'C'] 35*da0073e9SAndroid Build Coastguard Worker # 'None:3,None:2' -> size = [3, 2], names: ['None', 'None'] 36*da0073e9SAndroid Build Coastguard Worker # '3,2' -> size = [3, 2], names=None passed to ctor. 37*da0073e9SAndroid Build Coastguard Worker def parse_name(maybe_name): 38*da0073e9SAndroid Build Coastguard Worker maybe_name = maybe_name.strip() 39*da0073e9SAndroid Build Coastguard Worker if maybe_name == 'None': 40*da0073e9SAndroid Build Coastguard Worker return None 41*da0073e9SAndroid Build Coastguard Worker return maybe_name 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker string = string.strip() 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker # '' -> size: [], names:None 46*da0073e9SAndroid Build Coastguard Worker if len(string) == 0: 47*da0073e9SAndroid Build Coastguard Worker return None, [] 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker # '3, 2' -> size = [3, 2], None names. 50*da0073e9SAndroid Build Coastguard Worker if ':' not in string: 51*da0073e9SAndroid Build Coastguard Worker return None, [int(size) for size in string.split(',')] 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker dims = string.split(',') 54*da0073e9SAndroid Build Coastguard Worker tuples = [dim.split(':') for dim in dims] 55*da0073e9SAndroid Build Coastguard Worker return zip(*[(parse_name(name), int(size)) for name, size in tuples]) 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Workerdef create(namedshape, factory=torch.randn): 59*da0073e9SAndroid Build Coastguard Worker # namedshape: str 60*da0073e9SAndroid Build Coastguard Worker names, shape = parse_compressed_namedshape(namedshape) 61*da0073e9SAndroid Build Coastguard Worker return factory(shape, names=names) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Workerdef out_fn(operator): 65*da0073e9SAndroid Build Coastguard Worker @functools.wraps(operator) 66*da0073e9SAndroid Build Coastguard Worker def fn(*inputs): 67*da0073e9SAndroid Build Coastguard Worker return operator(*inputs[1:], out=inputs[0]) 68*da0073e9SAndroid Build Coastguard Worker return fn 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Workerclass TestNamedTensor(TestCase): 72*da0073e9SAndroid Build Coastguard Worker def test_aaa_must_run_first_check_experimental_warning(self): 73*da0073e9SAndroid Build Coastguard Worker # TODO(rzou): It would be nice for this to be a "real" python warning. 74*da0073e9SAndroid Build Coastguard Worker # Right now this error message only prints once and doesn't respect 75*da0073e9SAndroid Build Coastguard Worker # warnings.simplefilter behavior (where python users can control whether 76*da0073e9SAndroid Build Coastguard Worker # or not to display warnings once, all the time, or never). 77*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as warns: 78*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3, names=('N', 'C')) 79*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(warns), 1) 80*da0073e9SAndroid Build Coastguard Worker self.assertTrue(str(warns[0].message).startswith( 81*da0073e9SAndroid Build Coastguard Worker 'Named tensors and all their associated APIs are an experimental feature')) 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker def test_trivial(self): 84*da0073e9SAndroid Build Coastguard Worker pass 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker def _test_name_inference(self, op, args=(), expected_names=(), device='cpu', 87*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex=None): 88*da0073e9SAndroid Build Coastguard Worker casted_args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg 89*da0073e9SAndroid Build Coastguard Worker for arg in args] 90*da0073e9SAndroid Build Coastguard Worker if maybe_raises_regex is not None: 91*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, maybe_raises_regex): 92*da0073e9SAndroid Build Coastguard Worker result = op(*args) 93*da0073e9SAndroid Build Coastguard Worker return 94*da0073e9SAndroid Build Coastguard Worker result = op(*args) 95*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.names, expected_names, 96*da0073e9SAndroid Build Coastguard Worker msg=f'Name inference for {op.__name__} on device {device} failed') 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker # TODO(rzou): Some form of this check should be added to self.assertEqual. 99*da0073e9SAndroid Build Coastguard Worker # Right now I don't know what it should look like. 100*da0073e9SAndroid Build Coastguard Worker def assertTensorDataAndNamesEqual(self, x, y): 101*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.names, y.names) 102*da0073e9SAndroid Build Coastguard Worker unnamed_x = x.rename(None) 103*da0073e9SAndroid Build Coastguard Worker unnamed_y = y.rename(None) 104*da0073e9SAndroid Build Coastguard Worker self.assertEqual(unnamed_x, unnamed_y) 105*da0073e9SAndroid Build Coastguard Worker 106*da0073e9SAndroid Build Coastguard Worker def _test_factory(self, factory, device): 107*da0073e9SAndroid Build Coastguard Worker x = factory([], device=device) 108*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.names, ()) 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker x = factory(1, 2, 3, device=device) 111*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.names, (None, None, None)) 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Worker x = factory(1, 2, 3, names=None, device=device) 114*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.names, (None, None, None)) 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker x = factory(1, 2, 3, names=('N', 'T', 'D'), device=device) 117*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.names, ('N', 'T', 'D')) 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker x = factory(1, 2, 3, names=('N', None, 'D'), device=device) 120*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.names, ('N', None, 'D')) 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker x = factory(1, 2, 3, names=('_1', 'batch9', 'BATCH_5'), device=device) 123*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.names, ('_1', 'batch9', 'BATCH_5')) 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 126*da0073e9SAndroid Build Coastguard Worker 'a valid identifier contains only'): 127*da0073e9SAndroid Build Coastguard Worker x = factory(2, names=('1',), device=device) 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 130*da0073e9SAndroid Build Coastguard Worker 'a valid identifier contains only'): 131*da0073e9SAndroid Build Coastguard Worker x = factory(2, names=('?',), device=device) 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Number of names'): 134*da0073e9SAndroid Build Coastguard Worker x = factory(2, 1, names=('N',), device=device) 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, 'invalid combination of arguments'): 137*da0073e9SAndroid Build Coastguard Worker x = factory(2, 1, names='N', device=device) 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'construct a tensor with duplicate names'): 140*da0073e9SAndroid Build Coastguard Worker x = factory(2, 1, 1, names=('N', 'C', 'N'), device=device) 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker names64 = ['A' * i for i in range(1, 65)] 143*da0073e9SAndroid Build Coastguard Worker x = factory([1] * 64, names=names64, device=device) 144*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.names, names64) 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 147*da0073e9SAndroid Build Coastguard Worker RuntimeError, 148*da0073e9SAndroid Build Coastguard Worker 'only support up to 64 dims'): 149*da0073e9SAndroid Build Coastguard Worker names65 = ['A' * i for i in range(1, 66)] 150*da0073e9SAndroid Build Coastguard Worker x = factory([1] * 65, names=names64, device=device) 151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("not a bug: Dynamo causes the refcounts to be different") 153*da0073e9SAndroid Build Coastguard Worker def test_none_names_refcount(self, N=10): 154*da0073e9SAndroid Build Coastguard Worker def scope(): 155*da0073e9SAndroid Build Coastguard Worker unnamed = torch.empty(2, 3) 156*da0073e9SAndroid Build Coastguard Worker unnamed.names # materialize [None, None] 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker prev_none_refcnt = sys.getrefcount(None) 159*da0073e9SAndroid Build Coastguard Worker # Ran it N times to reduce flakiness 160*da0073e9SAndroid Build Coastguard Worker [scope() for i in range(N)] 161*da0073e9SAndroid Build Coastguard Worker after_none_refcnt = sys.getrefcount(None) 162*da0073e9SAndroid Build Coastguard Worker self.assertTrue(after_none_refcnt - prev_none_refcnt < N / 2, 163*da0073e9SAndroid Build Coastguard Worker msg='Using tensor.names should not change ' 164*da0073e9SAndroid Build Coastguard Worker 'the refcount of Py_None') 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker def test_has_names(self): 167*da0073e9SAndroid Build Coastguard Worker unnamed = torch.empty(2, 3) 168*da0073e9SAndroid Build Coastguard Worker none_named = torch.empty(2, 3, names=(None, None)) 169*da0073e9SAndroid Build Coastguard Worker partially_named = torch.empty(2, 3, names=('N', None)) 170*da0073e9SAndroid Build Coastguard Worker fully_named = torch.empty(2, 3, names=('N', 'C')) 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker self.assertFalse(unnamed.has_names()) 173*da0073e9SAndroid Build Coastguard Worker self.assertFalse(none_named.has_names()) 174*da0073e9SAndroid Build Coastguard Worker self.assertTrue(partially_named.has_names()) 175*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fully_named.has_names()) 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker def test_py3_ellipsis(self): 178*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(2, 3, 5, 7) 179*da0073e9SAndroid Build Coastguard Worker output = tensor.refine_names('N', ..., 'C') 180*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['N', None, None, 'C']) 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker def test_refine_names(self): 183*da0073e9SAndroid Build Coastguard Worker # Unnamed tensor -> Unnamed tensor 184*da0073e9SAndroid Build Coastguard Worker self._test_name_inference(Tensor.refine_names, 185*da0073e9SAndroid Build Coastguard Worker [create('None:1,None:2,None:3'), 'N', 'C', 'H'], 186*da0073e9SAndroid Build Coastguard Worker ['N', 'C', 'H']) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker # Named tensor -> Named tensor 189*da0073e9SAndroid Build Coastguard Worker self._test_name_inference(Tensor.refine_names, 190*da0073e9SAndroid Build Coastguard Worker [create('N:1,C:2,H:3'), 'N', 'C', 'H'], 191*da0073e9SAndroid Build Coastguard Worker ['N', 'C', 'H']) 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker # Partially named tensor -> named tensor 194*da0073e9SAndroid Build Coastguard Worker self._test_name_inference(Tensor.refine_names, 195*da0073e9SAndroid Build Coastguard Worker [create('None:1,C:2,None:3'), None, 'C', 'H'], 196*da0073e9SAndroid Build Coastguard Worker [None, 'C', 'H']) 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker # Too few names 199*da0073e9SAndroid Build Coastguard Worker self._test_name_inference(Tensor.refine_names, 200*da0073e9SAndroid Build Coastguard Worker [create('None:2,None:3'), 'N', 'C', 'H'], 201*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex="different number of dims") 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker # Cannot change Tensor[D] to Tensor[N] 204*da0073e9SAndroid Build Coastguard Worker self._test_name_inference(Tensor.refine_names, 205*da0073e9SAndroid Build Coastguard Worker [create('D:3'), 'N'], 206*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex="is different from") 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker # Cannot change Tensor[D] to Tensor[None] 209*da0073e9SAndroid Build Coastguard Worker self._test_name_inference(Tensor.refine_names, 210*da0073e9SAndroid Build Coastguard Worker [create('D:3'), None], 211*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex="'D' is more specific than None") 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker # globbing behavior exists 214*da0073e9SAndroid Build Coastguard Worker self._test_name_inference(Tensor.refine_names, 215*da0073e9SAndroid Build Coastguard Worker [create('None:1,None:1,None:2,None:3'), '...', 'C', 'H'], 216*da0073e9SAndroid Build Coastguard Worker [None, None, 'C', 'H']) 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker def test_detach(self): 219*da0073e9SAndroid Build Coastguard Worker names = ['N'] 220*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 221*da0073e9SAndroid Build Coastguard Worker Tensor.detach_, 222*da0073e9SAndroid Build Coastguard Worker [torch.randn(3, requires_grad=True, names=names)], 223*da0073e9SAndroid Build Coastguard Worker names) 224*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 225*da0073e9SAndroid Build Coastguard Worker Tensor.detach, 226*da0073e9SAndroid Build Coastguard Worker [torch.randn(3, requires_grad=True, names=names)], 227*da0073e9SAndroid Build Coastguard Worker names) 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker def test_index_fill(self): 230*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 231*da0073e9SAndroid Build Coastguard Worker expected_names = ('N', 'C') 232*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 5, device=device, names=expected_names) 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker output = x.index_fill_('C', torch.tensor([0, 1], device=device), 5) 235*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, expected_names) 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker output = x.index_fill_('C', torch.tensor([0, 1], device=device), torch.tensor(4.)) 238*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, expected_names) 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker output = x.index_fill('C', torch.tensor([0, 1], device=device), 5) 241*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, expected_names) 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker output = x.index_fill('C', torch.tensor([0, 1], device=device), torch.tensor(4.)) 244*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, expected_names) 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker def test_equal(self): 247*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 248*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(2, 3, device=device) 249*da0073e9SAndroid Build Coastguard Worker other = tensor.clone() 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal(tensor.rename('N', 'C'), other.rename('N', 'C'))) 252*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.equal(tensor.rename('M', 'C'), other.rename('N', 'C'))) 253*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.equal(tensor.rename(None, 'C'), other.rename('N', 'C'))) 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker def test_squeeze(self): 256*da0073e9SAndroid Build Coastguard Worker x = create('N:3,C:1,H:1,W:1') 257*da0073e9SAndroid Build Coastguard Worker output = x.squeeze('C') 258*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['N', 'H', 'W']) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker output = x.squeeze() 261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['N']) 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker def test_repr(self): 264*da0073e9SAndroid Build Coastguard Worker named_tensor = torch.zeros(2, 3).rename_('N', 'C') 265*da0073e9SAndroid Build Coastguard Worker expected = "tensor([[0., 0., 0.],\n [0., 0., 0.]], names=('N', 'C'))" 266*da0073e9SAndroid Build Coastguard Worker self.assertEqual(repr(named_tensor), expected) 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker unnamed_tensor = torch.zeros(2, 3) 269*da0073e9SAndroid Build Coastguard Worker expected = "tensor([[0., 0., 0.],\n [0., 0., 0.]])" 270*da0073e9SAndroid Build Coastguard Worker self.assertEqual(repr(unnamed_tensor), expected) 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker none_named_tensor = torch.zeros(2, 3).rename_(None, None) 273*da0073e9SAndroid Build Coastguard Worker self.assertEqual(repr(none_named_tensor), expected) 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker def test_diagonal(self): 276*da0073e9SAndroid Build Coastguard Worker named_tensor = torch.zeros(2, 3, 5, 7, names=list('ABCD')) 277*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor.diagonal().names, ['C', 'D', None]) 278*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor.diagonal(1, 3).names, ['A', 'C', None]) 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor.diagonal(outdim='E', dim1='B', dim2='D').names, 281*da0073e9SAndroid Build Coastguard Worker ['A', 'C', 'E']) 282*da0073e9SAndroid Build Coastguard Worker 283*da0073e9SAndroid Build Coastguard Worker def test_max_pooling(self): 284*da0073e9SAndroid Build Coastguard Worker def check_tuple_return(op, inputs, expected_names): 285*da0073e9SAndroid Build Coastguard Worker values, indices = op(*inputs) 286*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values.names, expected_names) 287*da0073e9SAndroid Build Coastguard Worker self.assertEqual(indices.names, expected_names) 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker named_tensor_1d = torch.zeros(2, 3, 5, device=device, names=list('ABC')) 292*da0073e9SAndroid Build Coastguard Worker named_tensor_2d = torch.zeros(2, 3, 5, 7, device=device, names=list('ABCD')) 293*da0073e9SAndroid Build Coastguard Worker named_tensor_3d = torch.zeros(2, 3, 5, 7, 9, device=device, names=list('ABCDE')) 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker self.assertEqual(F.max_pool1d(named_tensor_1d, 2).names, named_tensor_1d.names) 296*da0073e9SAndroid Build Coastguard Worker self.assertEqual(F.max_pool2d(named_tensor_2d, [2, 2]).names, named_tensor_2d.names) 297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(F.max_pool3d(named_tensor_3d, [2, 2, 2]).names, named_tensor_3d.names) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker check_tuple_return(F.max_pool1d_with_indices, [named_tensor_1d, 2], named_tensor_1d.names) 300*da0073e9SAndroid Build Coastguard Worker check_tuple_return(F.max_pool2d_with_indices, [named_tensor_2d, [2, 2]], named_tensor_2d.names) 301*da0073e9SAndroid Build Coastguard Worker check_tuple_return(F.max_pool3d_with_indices, [named_tensor_3d, [2, 2, 2]], named_tensor_3d.names) 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker def test_max_pooling_without_names_does_not_warn(self): 304*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 305*da0073e9SAndroid Build Coastguard Worker tensor_2d = torch.zeros(2, 3, 5, 7, device=device, requires_grad=True) 306*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as warns: 307*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") 308*da0073e9SAndroid Build Coastguard Worker result = F.max_pool2d(tensor_2d, [2, 2]) 309*da0073e9SAndroid Build Coastguard Worker result.sum().backward() 310*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(warns), 0) 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker def test_no_save_support(self): 313*da0073e9SAndroid Build Coastguard Worker named_tensor = torch.zeros(2, 3, names=('N', 'C')) 314*da0073e9SAndroid Build Coastguard Worker buf = io.BytesIO() 315*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "NYI"): 316*da0073e9SAndroid Build Coastguard Worker torch.save(named_tensor, buf) 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker def test_no_pickle_support(self): 319*da0073e9SAndroid Build Coastguard Worker named_tensor = torch.zeros(2, 3, names=('N', 'C')) 320*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "NYI"): 321*da0073e9SAndroid Build Coastguard Worker serialized = pickle.dumps(named_tensor) 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker def test_no_multiprocessing_support(self): 324*da0073e9SAndroid Build Coastguard Worker named_tensor = torch.zeros(2, 3, names=('N', 'C')) 325*da0073e9SAndroid Build Coastguard Worker buf = io.BytesIO() 326*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "NYI"): 327*da0073e9SAndroid Build Coastguard Worker ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(named_tensor) 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker def test_big_tensor_repr_has_names(self): 330*da0073e9SAndroid Build Coastguard Worker def check_repr(named_tensor): 331*da0073e9SAndroid Build Coastguard Worker unnamed_tensor = named_tensor.rename(None) 332*da0073e9SAndroid Build Coastguard Worker names_tag = f'names={named_tensor.names}' 333*da0073e9SAndroid Build Coastguard Worker self.assertIn(names_tag, repr(named_tensor)) 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker check_repr(torch.randn(128, 3, 64, 64, names=('N', 'C', 'H', 'W'))) 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker def test_noncontig_contiguous(self): 338*da0073e9SAndroid Build Coastguard Worker # This type of contiguous is special-cased and therefore needs its own test 339*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 340*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, device=device).t().rename_('N', 'C') 341*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.contiguous().names, ('N', 'C')) 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker def test_copy_transpose(self): 344*da0073e9SAndroid Build Coastguard Worker # This type of copy is special-cased and therefore needs its own test 345*da0073e9SAndroid Build Coastguard Worker def _test(self_names, other_names, expected_names): 346*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2, 5, names=self_names) 347*da0073e9SAndroid Build Coastguard Worker y = torch.empty(5, 2).t().rename_(*other_names) 348*da0073e9SAndroid Build Coastguard Worker x.copy_(y) 349*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.names, expected_names) 350*da0073e9SAndroid Build Coastguard Worker 351*da0073e9SAndroid Build Coastguard Worker _test(('N', 'C'), ('N', 'C'), ('N', 'C')) 352*da0073e9SAndroid Build Coastguard Worker _test(None, ('N', 'C'), ('N', 'C')) 353*da0073e9SAndroid Build Coastguard Worker 354*da0073e9SAndroid Build Coastguard Worker def test_rename_(self): 355*da0073e9SAndroid Build Coastguard Worker tensor = torch.empty(1, 1, names=('N', 'C')) 356*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.rename_(None).names, (None, None)) 357*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.rename_('H', 'W').names, ('H', 'W')) 358*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Number of names'): 359*da0073e9SAndroid Build Coastguard Worker tensor.rename_('N', 'C', 'W') 360*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'duplicate names'): 361*da0073e9SAndroid Build Coastguard Worker tensor.rename_('N', 'N') 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker def test_rename(self): 364*da0073e9SAndroid Build Coastguard Worker tensor = torch.empty(1, 1, names=('N', 'C')) 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.rename(None).names, (None, None)) 367*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.rename('H', 'W').names, ('H', 'W')) 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker # Check that we didn't modify tensor.names 370*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.names, ('N', 'C')) 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Number of names'): 373*da0073e9SAndroid Build Coastguard Worker tensor.rename('N', 'C', 'W') 374*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'duplicate names'): 375*da0073e9SAndroid Build Coastguard Worker tensor.rename('N', 'N') 376*da0073e9SAndroid Build Coastguard Worker 377*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'either positional args or keyword args'): 378*da0073e9SAndroid Build Coastguard Worker tensor.rename(None, N='batch') 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker # rename returns a view on the tensor 381*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.rename('H', 'W').data_ptr(), tensor.data_ptr()) 382*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.rename(None).data_ptr(), tensor.data_ptr()) 383*da0073e9SAndroid Build Coastguard Worker 384*da0073e9SAndroid Build Coastguard Worker def test_rename_globber(self): 385*da0073e9SAndroid Build Coastguard Worker scalar = torch.randn([]) 386*da0073e9SAndroid Build Coastguard Worker unnamed_tensor = torch.empty(1, 1, 1, 1) 387*da0073e9SAndroid Build Coastguard Worker named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W')) 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scalar.rename(None).names, []) 390*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scalar.rename('...').names, []) 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Worker # Check that it works with unnamed tensors 393*da0073e9SAndroid Build Coastguard Worker self.assertEqual(unnamed_tensor.rename('...').names, unnamed_tensor.names) 394*da0073e9SAndroid Build Coastguard Worker self.assertEqual(unnamed_tensor.rename('...', 'H', 'W').names, 395*da0073e9SAndroid Build Coastguard Worker [None, None, 'H', 'W']) 396*da0073e9SAndroid Build Coastguard Worker self.assertEqual(unnamed_tensor.rename('N', '...', 'W').names, 397*da0073e9SAndroid Build Coastguard Worker ['N', None, None, 'W']) 398*da0073e9SAndroid Build Coastguard Worker self.assertEqual(unnamed_tensor.rename('N', 'C', '...').names, 399*da0073e9SAndroid Build Coastguard Worker ['N', 'C', None, None]) 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker # Check that it works with named tensors 402*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor.rename('...').names, named_tensor.names) 403*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor.rename('...', 'width').names, 404*da0073e9SAndroid Build Coastguard Worker ['N', 'C', 'H', 'width']) 405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor.rename('batch', 'channels', '...', 'width').names, 406*da0073e9SAndroid Build Coastguard Worker ['batch', 'channels', 'H', 'width']) 407*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor.rename('batch', '...').names, 408*da0073e9SAndroid Build Coastguard Worker ['batch', 'C', 'H', 'W']) 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker # Test empty glob 411*da0073e9SAndroid Build Coastguard Worker self.assertEqual(unnamed_tensor.rename('...', None, None, None, None).names, 412*da0073e9SAndroid Build Coastguard Worker [None, None, None, None]) 413*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor.rename('N', 'C', 'H', '...', 'W').names, 414*da0073e9SAndroid Build Coastguard Worker ['N', 'C', 'H', 'W']) 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker # Multiple globs throw 417*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'More than one '): 418*da0073e9SAndroid Build Coastguard Worker named_tensor.rename('...', 'channels', '...') 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Worker def test_rename_rename_map(self): 421*da0073e9SAndroid Build Coastguard Worker scalar = torch.randn([]) 422*da0073e9SAndroid Build Coastguard Worker unnamed_tensor = torch.empty(1, 1, 1, 1) 423*da0073e9SAndroid Build Coastguard Worker named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W')) 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"): 426*da0073e9SAndroid Build Coastguard Worker scalar.rename(N='batch') 427*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"): 428*da0073e9SAndroid Build Coastguard Worker unnamed_tensor.rename(N='batch') 429*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"): 430*da0073e9SAndroid Build Coastguard Worker named_tensor.rename(B='batch') 431*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"): 432*da0073e9SAndroid Build Coastguard Worker named_tensor.rename(H='height', B='batch') 433*da0073e9SAndroid Build Coastguard Worker 434*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor.rename(N='batch').data_ptr(), 435*da0073e9SAndroid Build Coastguard Worker named_tensor.data_ptr()) 436*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor.rename(N='batch').names, 437*da0073e9SAndroid Build Coastguard Worker ['batch', 'C', 'H', 'W']) 438*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named_tensor.rename(N='batch', H='height').names, 439*da0073e9SAndroid Build Coastguard Worker ['batch', 'C', 'height', 'W']) 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker def test_set_names_property(self): 442*da0073e9SAndroid Build Coastguard Worker tensor = torch.empty(1, 1, names=('N', 'C')) 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker tensor.names = None 445*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.names, (None, None)) 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker tensor.names = ('N', 'W') 448*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.names, ('N', 'W')) 449*da0073e9SAndroid Build Coastguard Worker 450*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Number of names'): 451*da0073e9SAndroid Build Coastguard Worker tensor.names = ['N', 'C', 'W'] 452*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'duplicate names'): 453*da0073e9SAndroid Build Coastguard Worker tensor.names = ['N', 'N'] 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Worker def test_factory_edge_cases(self): 456*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 457*da0073e9SAndroid Build Coastguard Worker self._test_factory(torch.empty, device) 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Worker def test_factory_coverage(self): 460*da0073e9SAndroid Build Coastguard Worker def _test(factory, device): 461*da0073e9SAndroid Build Coastguard Worker names = ('N', 'T', 'D') 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(0) 464*da0073e9SAndroid Build Coastguard Worker result = factory(1, 2, 3, names=names, device=device) 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(0) 467*da0073e9SAndroid Build Coastguard Worker expected = factory(1, 2, 3, device=device).rename_(*names) 468*da0073e9SAndroid Build Coastguard Worker 469*da0073e9SAndroid Build Coastguard Worker self.assertTensorDataAndNamesEqual(result, expected) 470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker supported = [ 472*da0073e9SAndroid Build Coastguard Worker torch.ones, 473*da0073e9SAndroid Build Coastguard Worker torch.rand, 474*da0073e9SAndroid Build Coastguard Worker torch.randn, 475*da0073e9SAndroid Build Coastguard Worker torch.zeros, 476*da0073e9SAndroid Build Coastguard Worker ] 477*da0073e9SAndroid Build Coastguard Worker 478*da0073e9SAndroid Build Coastguard Worker for op, device in itertools.product(supported, get_all_device_types()): 479*da0073e9SAndroid Build Coastguard Worker _test(op, device) 480*da0073e9SAndroid Build Coastguard Worker 481*da0073e9SAndroid Build Coastguard Worker # Test torch.full 482*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 483*da0073e9SAndroid Build Coastguard Worker names = ('N', 'T', 'D') 484*da0073e9SAndroid Build Coastguard Worker result = torch.full([1, 2, 3], 2., names=names, device=device) 485*da0073e9SAndroid Build Coastguard Worker expected = torch.full([1, 2, 3], 2., device=device).rename_(*names) 486*da0073e9SAndroid Build Coastguard Worker self.assertTensorDataAndNamesEqual(result, expected) 487*da0073e9SAndroid Build Coastguard Worker 488*da0073e9SAndroid Build Coastguard Worker def test_tensor_from_lists(self): 489*da0073e9SAndroid Build Coastguard Worker names = ('N', 'C') 490*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor([[1]], names=names) 491*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.names, names) 492*da0073e9SAndroid Build Coastguard Worker 493*da0073e9SAndroid Build Coastguard Worker names = ('N',) 494*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor([1], names=names) 495*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.names, names) 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Number of names'): 498*da0073e9SAndroid Build Coastguard Worker names = ('N', 'C') 499*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor([1], names=names) 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_NUMPY, "no numpy") 502*da0073e9SAndroid Build Coastguard Worker def test_tensor_from_numpy(self): 503*da0073e9SAndroid Build Coastguard Worker import numpy as np 504*da0073e9SAndroid Build Coastguard Worker arr = np.array([[1]]) 505*da0073e9SAndroid Build Coastguard Worker names = ('N', 'C') 506*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor([[1]], names=names) 507*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.names, names) 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker def test_tensor_from_tensor(self): 510*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1) 511*da0073e9SAndroid Build Coastguard Worker names = ('N', 'C') 512*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor(x, names=names) 513*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.names, names) 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker def test_tensor_from_named_tensor(self): 516*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1, names=('N', 'D')) 517*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor(x) 518*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.names, ('N', 'D')) 519*da0073e9SAndroid Build Coastguard Worker 520*da0073e9SAndroid Build Coastguard Worker # there's no way to distinguish between names=None and not passing in names. 521*da0073e9SAndroid Build Coastguard Worker # If the user passes in names=None they are asking for trouble. 522*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1, names=('N', 'D')) 523*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor(x, names=None) 524*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.names, ('N', 'D')) 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 1, names=('N', 'D')) 527*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Name mismatch"): 528*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor(x, names=('N', 'C')) 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Worker def test_size(self): 531*da0073e9SAndroid Build Coastguard Worker t = torch.empty(2, 3, 5, names=('N', None, 'C')) 532*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.size('N'), 2) 533*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.size('C'), 5) 534*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '): 535*da0073e9SAndroid Build Coastguard Worker t.size('channels') 536*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '): 537*da0073e9SAndroid Build Coastguard Worker torch.empty(2, 3, 4).size('N') 538*da0073e9SAndroid Build Coastguard Worker 539*da0073e9SAndroid Build Coastguard Worker def test_stride(self): 540*da0073e9SAndroid Build Coastguard Worker t = torch.empty(2, 3, 5, names=('N', None, 'C')) 541*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.stride('N'), 3 * 5) 542*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.stride('C'), 1) 543*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '): 544*da0073e9SAndroid Build Coastguard Worker t.stride('channels') 545*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '): 546*da0073e9SAndroid Build Coastguard Worker torch.empty(2, 3, 4).stride('N') 547*da0073e9SAndroid Build Coastguard Worker 548*da0073e9SAndroid Build Coastguard Worker def test_transpose_variants(self): 549*da0073e9SAndroid Build Coastguard Worker t = torch.randn(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) 550*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.transpose('N', 'C').names, ['C', 'N', 'H', 'W']) 551*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.transpose(1, 3).names, ['N', 'W', 'H', 'C']) 552*da0073e9SAndroid Build Coastguard Worker 553*da0073e9SAndroid Build Coastguard Worker t = torch.randn(2, 3, names=('N', 'C')) 554*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.t().names, ['C', 'N']) 555*da0073e9SAndroid Build Coastguard Worker 556*da0073e9SAndroid Build Coastguard Worker def test_resize(self): 557*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 558*da0073e9SAndroid Build Coastguard Worker named = torch.randn(2, names=('N',), device=device) 559*da0073e9SAndroid Build Coastguard Worker named.resize_([2]) 560*da0073e9SAndroid Build Coastguard Worker self.assertEqual(named.names, ['N']) 561*da0073e9SAndroid Build Coastguard Worker 562*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot resize named tensor"): 563*da0073e9SAndroid Build Coastguard Worker named.resize_([3]) 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker other_named = torch.randn(2, names=('N',), device=device) 566*da0073e9SAndroid Build Coastguard Worker named.resize_as_(other_named) 567*da0073e9SAndroid Build Coastguard Worker self.assertEqual(other_named.names, ['N']) 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker unnamed = torch.randn(2, device=device) 570*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 571*da0073e9SAndroid Build Coastguard Worker RuntimeError, r'names .* are not the same as the computed output names'): 572*da0073e9SAndroid Build Coastguard Worker named.resize_as_(unnamed) 573*da0073e9SAndroid Build Coastguard Worker 574*da0073e9SAndroid Build Coastguard Worker unnamed = torch.randn(1, device=device) 575*da0073e9SAndroid Build Coastguard Worker unnamed.resize_as_(named) 576*da0073e9SAndroid Build Coastguard Worker self.assertEqual(unnamed.names, ['N']) 577*da0073e9SAndroid Build Coastguard Worker 578*da0073e9SAndroid Build Coastguard Worker def test_cdist(self): 579*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 580*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(3, 1, 2, 7, names=('M', 'N', 'first_group', 'features'), 581*da0073e9SAndroid Build Coastguard Worker device=device) 582*da0073e9SAndroid Build Coastguard Worker other = torch.randn(5, 11, 7, names=('N', 'second_group', 'features'), 583*da0073e9SAndroid Build Coastguard Worker device=device) 584*da0073e9SAndroid Build Coastguard Worker result = torch.cdist(tensor, other) 585*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.names, ['M', 'N', 'first_group', 'second_group']) 586*da0073e9SAndroid Build Coastguard Worker 587*da0073e9SAndroid Build Coastguard Worker def test_info_smoke(self): 588*da0073e9SAndroid Build Coastguard Worker # Smoke test for info functions / methods / attributes on named tensors. 589*da0073e9SAndroid Build Coastguard Worker tensor = torch.empty(1, 1, names=('N', 'D')) 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker tensor.device 592*da0073e9SAndroid Build Coastguard Worker tensor.dtype 593*da0073e9SAndroid Build Coastguard Worker tensor.get_device() 594*da0073e9SAndroid Build Coastguard Worker tensor.is_complex() 595*da0073e9SAndroid Build Coastguard Worker tensor.is_floating_point() 596*da0073e9SAndroid Build Coastguard Worker tensor.is_nonzero() 597*da0073e9SAndroid Build Coastguard Worker torch.is_same_size(tensor, tensor) 598*da0073e9SAndroid Build Coastguard Worker torch.is_signed(tensor) 599*da0073e9SAndroid Build Coastguard Worker tensor.layout 600*da0073e9SAndroid Build Coastguard Worker tensor.numel() 601*da0073e9SAndroid Build Coastguard Worker tensor.dim() 602*da0073e9SAndroid Build Coastguard Worker tensor.element_size() 603*da0073e9SAndroid Build Coastguard Worker tensor.is_contiguous() 604*da0073e9SAndroid Build Coastguard Worker tensor.is_cuda 605*da0073e9SAndroid Build Coastguard Worker tensor.is_leaf 606*da0073e9SAndroid Build Coastguard Worker tensor.is_pinned() 607*da0073e9SAndroid Build Coastguard Worker tensor.is_shared() 608*da0073e9SAndroid Build Coastguard Worker tensor.is_sparse 609*da0073e9SAndroid Build Coastguard Worker tensor.ndimension() 610*da0073e9SAndroid Build Coastguard Worker tensor.nelement() 611*da0073e9SAndroid Build Coastguard Worker tensor.shape 612*da0073e9SAndroid Build Coastguard Worker tensor.size() 613*da0073e9SAndroid Build Coastguard Worker tensor.size(1) 614*da0073e9SAndroid Build Coastguard Worker tensor.storage() 615*da0073e9SAndroid Build Coastguard Worker tensor.storage_offset() 616*da0073e9SAndroid Build Coastguard Worker tensor.storage_type() 617*da0073e9SAndroid Build Coastguard Worker tensor.stride() 618*da0073e9SAndroid Build Coastguard Worker tensor.stride(1) 619*da0073e9SAndroid Build Coastguard Worker tensor.data 620*da0073e9SAndroid Build Coastguard Worker tensor.data_ptr() 621*da0073e9SAndroid Build Coastguard Worker tensor.ndim 622*da0073e9SAndroid Build Coastguard Worker tensor.item() 623*da0073e9SAndroid Build Coastguard Worker tensor.type() 624*da0073e9SAndroid Build Coastguard Worker tensor.is_shared() 625*da0073e9SAndroid Build Coastguard Worker tensor.is_signed() 626*da0073e9SAndroid Build Coastguard Worker 627*da0073e9SAndroid Build Coastguard Worker def test_autograd_smoke(self): 628*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3, names=('N', 'D'), requires_grad=True) 629*da0073e9SAndroid Build Coastguard Worker 630*da0073e9SAndroid Build Coastguard Worker y = x.clone() 631*da0073e9SAndroid Build Coastguard Worker y.retain_grad() 632*da0073e9SAndroid Build Coastguard Worker y.register_hook(lambda x: x) 633*da0073e9SAndroid Build Coastguard Worker 634*da0073e9SAndroid Build Coastguard Worker y.sum().backward() 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker # autograd related attributes 637*da0073e9SAndroid Build Coastguard Worker tensor = torch.empty(1, 1, names=('N', 'D'), requires_grad=True) 638*da0073e9SAndroid Build Coastguard Worker tensor = tensor.relu() 639*da0073e9SAndroid Build Coastguard Worker tensor.output_nr 640*da0073e9SAndroid Build Coastguard Worker tensor.grad_fn 641*da0073e9SAndroid Build Coastguard Worker tensor.requires_grad 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker def test_split_fns_propagates_names(self): 644*da0073e9SAndroid Build Coastguard Worker fns = [ 645*da0073e9SAndroid Build Coastguard Worker lambda x: x.split(1, 0), 646*da0073e9SAndroid Build Coastguard Worker lambda x: x.split([1, 1], 1), 647*da0073e9SAndroid Build Coastguard Worker lambda x: x.chunk(2, 0), 648*da0073e9SAndroid Build Coastguard Worker ] 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 651*da0073e9SAndroid Build Coastguard Worker orig_tensor = torch.empty(2, 2, names=('N', 'D'), device=device) 652*da0073e9SAndroid Build Coastguard Worker for fn in fns: 653*da0073e9SAndroid Build Coastguard Worker splits = fn(orig_tensor) 654*da0073e9SAndroid Build Coastguard Worker for split in splits: 655*da0073e9SAndroid Build Coastguard Worker self.assertEqual(split.names, orig_tensor.names) 656*da0073e9SAndroid Build Coastguard Worker 657*da0073e9SAndroid Build Coastguard Worker def test_any_all(self): 658*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 659*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(3, dtype=torch.bool, device=device, names=('C',)) 660*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.any().names, []) 661*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.all().names, []) 662*da0073e9SAndroid Build Coastguard Worker 663*da0073e9SAndroid Build Coastguard Worker def test_addcmul_addcdiv(self): 664*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 665*da0073e9SAndroid Build Coastguard Worker names = ['N'] 666*da0073e9SAndroid Build Coastguard Worker a = torch.rand(3, device=device, names=names) 667*da0073e9SAndroid Build Coastguard Worker b = torch.rand(3, device=device, names=names) 668*da0073e9SAndroid Build Coastguard Worker # avoid division by 0 669*da0073e9SAndroid Build Coastguard Worker c = torch.rand(3, device=device, names=names).clamp_min_(0.1) 670*da0073e9SAndroid Build Coastguard Worker out = torch.randn(3, device=device, names=names) 671*da0073e9SAndroid Build Coastguard Worker 672*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.addcmul(a, b, c).names, names) 673*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.addcmul(a, b, c, out=out).names, names) 674*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.addcmul_(b, c).names, names) 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.addcdiv(a, b, c).names, names) 677*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.addcdiv(a, b, c, out=out).names, names) 678*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.addcdiv_(b, c).names, names) 679*da0073e9SAndroid Build Coastguard Worker 680*da0073e9SAndroid Build Coastguard Worker def test_binary_ops(self): 681*da0073e9SAndroid Build Coastguard Worker def test_basic(op): 682*da0073e9SAndroid Build Coastguard Worker a = torch.empty(2, 3, names=('N', 'C')) 683*da0073e9SAndroid Build Coastguard Worker b = torch.empty(3, 2, names=('C', 'N')) 684*da0073e9SAndroid Build Coastguard Worker c = torch.empty(3, names=('C',)) 685*da0073e9SAndroid Build Coastguard Worker d = torch.empty(5, names=('W',)) 686*da0073e9SAndroid Build Coastguard Worker 687*da0073e9SAndroid Build Coastguard Worker self.assertEqual(op(a, a).names, ('N', 'C')) 688*da0073e9SAndroid Build Coastguard Worker self.assertEqual(op(a, c).names, ('N', 'C')) 689*da0073e9SAndroid Build Coastguard Worker # TODO: dynamo will throw a slightly different 690*da0073e9SAndroid Build Coastguard Worker # error message because it's adding fake tensors 691*da0073e9SAndroid Build Coastguard Worker # `must match the size of` portion is the dynamo error 692*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"): 693*da0073e9SAndroid Build Coastguard Worker op(a, d) 694*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "do not match|must match the size of"): 695*da0073e9SAndroid Build Coastguard Worker op(a, b) 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker def test_wildcard(op): 698*da0073e9SAndroid Build Coastguard Worker a = torch.empty(2, 3, names=('N', 'C')) 699*da0073e9SAndroid Build Coastguard Worker c = torch.empty(2, 3, names=(None, 'C')) 700*da0073e9SAndroid Build Coastguard Worker self.assertEqual(op(a, c).names, ('N', 'C')) 701*da0073e9SAndroid Build Coastguard Worker 702*da0073e9SAndroid Build Coastguard Worker b = torch.empty(2, 3) 703*da0073e9SAndroid Build Coastguard Worker self.assertEqual(op(a, b).names, ('N', 'C')) 704*da0073e9SAndroid Build Coastguard Worker 705*da0073e9SAndroid Build Coastguard Worker d = torch.empty(2, 3, names=('C', None)) 706*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Misaligned"): 707*da0073e9SAndroid Build Coastguard Worker op(d, c) 708*da0073e9SAndroid Build Coastguard Worker 709*da0073e9SAndroid Build Coastguard Worker def test_mixed_unnamed_named(op, is_inplace): 710*da0073e9SAndroid Build Coastguard Worker named2 = torch.randn(1, 1, names=('N', 'C')) 711*da0073e9SAndroid Build Coastguard Worker unnamed1 = torch.randn(1) 712*da0073e9SAndroid Build Coastguard Worker unnamed2 = torch.randn(1, 1) 713*da0073e9SAndroid Build Coastguard Worker unnamed3 = torch.randn(1, 1, 1) 714*da0073e9SAndroid Build Coastguard Worker 715*da0073e9SAndroid Build Coastguard Worker def compute_expected_names(tensor, other): 716*da0073e9SAndroid Build Coastguard Worker assert tensor.has_names() ^ other.has_names() 717*da0073e9SAndroid Build Coastguard Worker named = tensor if tensor.has_names() else other 718*da0073e9SAndroid Build Coastguard Worker unnamed = other if tensor.has_names() else tensor 719*da0073e9SAndroid Build Coastguard Worker unnamed_dim = unnamed.dim() 720*da0073e9SAndroid Build Coastguard Worker if unnamed_dim > named.dim(): 721*da0073e9SAndroid Build Coastguard Worker return [None] * (unnamed_dim - named.dim()) + list(named.names) 722*da0073e9SAndroid Build Coastguard Worker else: 723*da0073e9SAndroid Build Coastguard Worker return named.names 724*da0073e9SAndroid Build Coastguard Worker 725*da0073e9SAndroid Build Coastguard Worker inputs = itertools.chain( 726*da0073e9SAndroid Build Coastguard Worker itertools.product([named2], [unnamed1, unnamed2, unnamed3]), 727*da0073e9SAndroid Build Coastguard Worker itertools.product([unnamed1, unnamed2, unnamed3], [named2]), 728*da0073e9SAndroid Build Coastguard Worker ) 729*da0073e9SAndroid Build Coastguard Worker if is_inplace: 730*da0073e9SAndroid Build Coastguard Worker # In-place ops have the constraint that they must not change shape. 731*da0073e9SAndroid Build Coastguard Worker inputs = [(a, b) for (a, b) in inputs if a.dim() >= b.dim()] 732*da0073e9SAndroid Build Coastguard Worker 733*da0073e9SAndroid Build Coastguard Worker for tensor, other in inputs: 734*da0073e9SAndroid Build Coastguard Worker expected_names = compute_expected_names(tensor, other) 735*da0073e9SAndroid Build Coastguard Worker self.assertEqual(op(tensor, other).names, expected_names) 736*da0073e9SAndroid Build Coastguard Worker 737*da0073e9SAndroid Build Coastguard Worker def method(name, *args, **kwargs): 738*da0073e9SAndroid Build Coastguard Worker return [Function(name, lambda a, b: getattr(a, name)(b, *args, **kwargs))] 739*da0073e9SAndroid Build Coastguard Worker 740*da0073e9SAndroid Build Coastguard Worker def function(name, *args, **kwargs): 741*da0073e9SAndroid Build Coastguard Worker return [Function(name, lambda a, b: getattr(torch, name)(a, b, *args, **kwargs))] 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker def out_function(name, *args, **kwargs): 744*da0073e9SAndroid Build Coastguard Worker out_fn = getattr(torch, name) 745*da0073e9SAndroid Build Coastguard Worker 746*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 747*da0073e9SAndroid Build Coastguard Worker result = torch.empty([0], dtype=a.dtype, device=a.device) 748*da0073e9SAndroid Build Coastguard Worker out_fn(a, b, *args, out=result, **kwargs) 749*da0073e9SAndroid Build Coastguard Worker return result 750*da0073e9SAndroid Build Coastguard Worker 751*da0073e9SAndroid Build Coastguard Worker return [Function(name, fn)] 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Worker def fn_method_and_inplace(name, *args, **kwargs): 754*da0073e9SAndroid Build Coastguard Worker return ( 755*da0073e9SAndroid Build Coastguard Worker method(name, *args, **kwargs) + 756*da0073e9SAndroid Build Coastguard Worker method(name + '_', *args, **kwargs) + 757*da0073e9SAndroid Build Coastguard Worker out_function(name, *args, **kwargs) 758*da0073e9SAndroid Build Coastguard Worker ) 759*da0073e9SAndroid Build Coastguard Worker 760*da0073e9SAndroid Build Coastguard Worker tests = [ 761*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('add'), 762*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('div'), 763*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('mul'), 764*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('sub'), 765*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('pow'), 766*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('atan2'), 767*da0073e9SAndroid Build Coastguard Worker method('copy_'), 768*da0073e9SAndroid Build Coastguard Worker function('floor_divide'), 769*da0073e9SAndroid Build Coastguard Worker function('true_divide'), 770*da0073e9SAndroid Build Coastguard Worker ] 771*da0073e9SAndroid Build Coastguard Worker tests = flatten(tests) 772*da0073e9SAndroid Build Coastguard Worker 773*da0073e9SAndroid Build Coastguard Worker for name, op in tests: 774*da0073e9SAndroid Build Coastguard Worker test_basic(op) 775*da0073e9SAndroid Build Coastguard Worker test_wildcard(op) 776*da0073e9SAndroid Build Coastguard Worker test_mixed_unnamed_named(op, is_inplace=name.endswith('_')) 777*da0073e9SAndroid Build Coastguard Worker 778*da0073e9SAndroid Build Coastguard Worker def test_logical_ops(self): 779*da0073e9SAndroid Build Coastguard Worker # Implemented via TensorIterator, so just check that each version 780*da0073e9SAndroid Build Coastguard Worker # (out-of-place, inplace, out=) propagates names. 781*da0073e9SAndroid Build Coastguard Worker def zeros(*args, **kwargs): 782*da0073e9SAndroid Build Coastguard Worker return torch.zeros(*args, dtype=torch.bool, **kwargs) 783*da0073e9SAndroid Build Coastguard Worker 784*da0073e9SAndroid Build Coastguard Worker for op in ('logical_xor', 'logical_and', 'logical_or'): 785*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 786*da0073e9SAndroid Build Coastguard Worker getattr(torch, op), 787*da0073e9SAndroid Build Coastguard Worker (create('N:2,C:3', zeros), create('N:2,C:3', zeros)), 788*da0073e9SAndroid Build Coastguard Worker expected_names=['N', 'C']) 789*da0073e9SAndroid Build Coastguard Worker 790*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 791*da0073e9SAndroid Build Coastguard Worker getattr(Tensor, op + '_'), 792*da0073e9SAndroid Build Coastguard Worker (create('N:2,C:3', zeros), create('N:2,C:3', zeros)), 793*da0073e9SAndroid Build Coastguard Worker expected_names=['N', 'C']) 794*da0073e9SAndroid Build Coastguard Worker 795*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 796*da0073e9SAndroid Build Coastguard Worker lambda out, x, y: getattr(torch, op)(x, y, out=out), 797*da0073e9SAndroid Build Coastguard Worker (create('0', zeros), create('N:2,C:3', zeros), create('N:2,C:3', zeros)), 798*da0073e9SAndroid Build Coastguard Worker expected_names=['N', 'C']) 799*da0073e9SAndroid Build Coastguard Worker 800*da0073e9SAndroid Build Coastguard Worker def test_pow_special(self): 801*da0073e9SAndroid Build Coastguard Worker # There are a few pow cases that don't go through TensorIterator. 802*da0073e9SAndroid Build Coastguard Worker # Test them here. 803*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 804*da0073e9SAndroid Build Coastguard Worker named = torch.randn(2, 3, names=('N', 'C'), device=device) 805*da0073e9SAndroid Build Coastguard Worker unnamed = torch.randn([0], device=device) 806*da0073e9SAndroid Build Coastguard Worker 807*da0073e9SAndroid Build Coastguard Worker result = torch.pow(named, 0, out=unnamed.clone()) 808*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.names, named.names) 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker result = torch.pow(named, 1, out=unnamed.clone()) 811*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.names, named.names) 812*da0073e9SAndroid Build Coastguard Worker 813*da0073e9SAndroid Build Coastguard Worker result = torch.pow(1, named, out=unnamed.clone()) 814*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.names, named.names) 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build Coastguard Worker def test_out_fn_semantics(self): 817*da0073e9SAndroid Build Coastguard Worker out_fn = torch.abs 818*da0073e9SAndroid Build Coastguard Worker unnamed_tensor = torch.randn(3, 2) 819*da0073e9SAndroid Build Coastguard Worker none_named_tensor = torch.randn(3, 2, names=(None, None)) 820*da0073e9SAndroid Build Coastguard Worker named_tensor = torch.randn(3, 2, names=('N', 'C')) 821*da0073e9SAndroid Build Coastguard Worker partially_named_tensor = torch.randn(3, 2, names=('N', None)) 822*da0073e9SAndroid Build Coastguard Worker 823*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Name mismatch"): 824*da0073e9SAndroid Build Coastguard Worker out_fn(partially_named_tensor, out=named_tensor) 825*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Name mismatch"): 826*da0073e9SAndroid Build Coastguard Worker out_fn(named_tensor, out=partially_named_tensor) 827*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Name mismatch"): 828*da0073e9SAndroid Build Coastguard Worker out_fn(none_named_tensor, out=named_tensor) 829*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Name mismatch"): 830*da0073e9SAndroid Build Coastguard Worker out_fn(unnamed_tensor, out=named_tensor) 831*da0073e9SAndroid Build Coastguard Worker 832*da0073e9SAndroid Build Coastguard Worker output = torch.randn(3, 2) 833*da0073e9SAndroid Build Coastguard Worker out_fn(unnamed_tensor, out=output) 834*da0073e9SAndroid Build Coastguard Worker self.assertFalse(output.has_names()) 835*da0073e9SAndroid Build Coastguard Worker 836*da0073e9SAndroid Build Coastguard Worker output = torch.randn(3, 2, names=(None, None)) 837*da0073e9SAndroid Build Coastguard Worker out_fn(named_tensor, out=output) 838*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, named_tensor.names) 839*da0073e9SAndroid Build Coastguard Worker 840*da0073e9SAndroid Build Coastguard Worker output = torch.randn(3, 2) 841*da0073e9SAndroid Build Coastguard Worker out_fn(named_tensor, out=output) 842*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, named_tensor.names) 843*da0073e9SAndroid Build Coastguard Worker 844*da0073e9SAndroid Build Coastguard Worker output = torch.randn(3, 2, names=(None, None)) 845*da0073e9SAndroid Build Coastguard Worker out_fn(unnamed_tensor, out=output) 846*da0073e9SAndroid Build Coastguard Worker self.assertFalse(output.has_names()) 847*da0073e9SAndroid Build Coastguard Worker 848*da0073e9SAndroid Build Coastguard Worker def test_unary_propagate_names_fns(self): 849*da0073e9SAndroid Build Coastguard Worker def _test(testcase, names=('N', 'D'), device='cpu'): 850*da0073e9SAndroid Build Coastguard Worker sizes = [2] * len(names) 851*da0073e9SAndroid Build Coastguard Worker tensor = torch.empty(sizes, names=names, device=device) 852*da0073e9SAndroid Build Coastguard Worker try: 853*da0073e9SAndroid Build Coastguard Worker out = testcase.lambd(tensor) 854*da0073e9SAndroid Build Coastguard Worker except RuntimeError as err: 855*da0073e9SAndroid Build Coastguard Worker # Get a better error message by catching the error and asserting. 856*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f'{testcase.name}: {err}') from err 857*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.names, tensor.names, 858*da0073e9SAndroid Build Coastguard Worker msg=testcase.name) 859*da0073e9SAndroid Build Coastguard Worker 860*da0073e9SAndroid Build Coastguard Worker def fn(name, *args, **kwargs): 861*da0073e9SAndroid Build Coastguard Worker return [Function(name, lambda t: getattr(torch, name)(t, *args, **kwargs))] 862*da0073e9SAndroid Build Coastguard Worker 863*da0073e9SAndroid Build Coastguard Worker def method(name, *args, **kwargs): 864*da0073e9SAndroid Build Coastguard Worker return [Function(name, lambda t: getattr(t, name)(*args, **kwargs))] 865*da0073e9SAndroid Build Coastguard Worker 866*da0073e9SAndroid Build Coastguard Worker def out_function(name, *args, **kwargs): 867*da0073e9SAndroid Build Coastguard Worker out_fn = getattr(torch, name) 868*da0073e9SAndroid Build Coastguard Worker 869*da0073e9SAndroid Build Coastguard Worker def fn(tensor): 870*da0073e9SAndroid Build Coastguard Worker result = torch.empty([0], dtype=tensor.dtype, device=tensor.device) 871*da0073e9SAndroid Build Coastguard Worker out_fn(tensor, *args, out=result, **kwargs) 872*da0073e9SAndroid Build Coastguard Worker return result 873*da0073e9SAndroid Build Coastguard Worker 874*da0073e9SAndroid Build Coastguard Worker return [Function(name + '_out', fn)] 875*da0073e9SAndroid Build Coastguard Worker 876*da0073e9SAndroid Build Coastguard Worker def fn_method_and_inplace(name, *args, **kwargs): 877*da0073e9SAndroid Build Coastguard Worker return ( 878*da0073e9SAndroid Build Coastguard Worker method(name, *args, **kwargs) + 879*da0073e9SAndroid Build Coastguard Worker method(name + '_', *args, **kwargs) + 880*da0073e9SAndroid Build Coastguard Worker out_function(name, *args, **kwargs) 881*da0073e9SAndroid Build Coastguard Worker ) 882*da0073e9SAndroid Build Coastguard Worker 883*da0073e9SAndroid Build Coastguard Worker # All of these operate on 2x2 tensors. 884*da0073e9SAndroid Build Coastguard Worker tests = [ 885*da0073e9SAndroid Build Coastguard Worker # unary pointwise 886*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('abs'), 887*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('acos'), 888*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('asin'), 889*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('atan'), 890*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('ceil'), 891*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('clamp', -1, 1), 892*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('clamp_min', -2), 893*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('clamp_max', 2), 894*da0073e9SAndroid Build Coastguard Worker method('cauchy_'), 895*da0073e9SAndroid Build Coastguard Worker method('clone'), 896*da0073e9SAndroid Build Coastguard Worker method('contiguous'), 897*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('cos'), 898*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('cosh'), 899*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('digamma'), 900*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('erf'), 901*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('erfc'), 902*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('erfinv'), 903*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('exp'), 904*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('expm1'), 905*da0073e9SAndroid Build Coastguard Worker method('exponential_'), 906*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('floor'), 907*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('frac'), 908*da0073e9SAndroid Build Coastguard Worker method('geometric_', p=0.5), 909*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('lgamma'), 910*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('log'), 911*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('log10'), 912*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('log1p'), 913*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('log2'), 914*da0073e9SAndroid Build Coastguard Worker method('log_normal_'), 915*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('neg'), 916*da0073e9SAndroid Build Coastguard Worker method('normal_'), 917*da0073e9SAndroid Build Coastguard Worker [Function('polygamma', lambda t: torch.polygamma(1, t))], 918*da0073e9SAndroid Build Coastguard Worker method('polygamma_', 1), 919*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('reciprocal'), 920*da0073e9SAndroid Build Coastguard Worker method('random_', 0, 1), 921*da0073e9SAndroid Build Coastguard Worker method('random_', 1), 922*da0073e9SAndroid Build Coastguard Worker method('random_'), 923*da0073e9SAndroid Build Coastguard Worker method('relu_'), 924*da0073e9SAndroid Build Coastguard Worker method('requires_grad_'), 925*da0073e9SAndroid Build Coastguard Worker method('relu'), 926*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('round'), 927*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('rsqrt'), 928*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('sigmoid'), 929*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('sign'), 930*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('sin'), 931*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('sinh'), 932*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('sqrt'), 933*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('tan'), 934*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('tanh'), 935*da0073e9SAndroid Build Coastguard Worker fn('threshold', 0, 1), 936*da0073e9SAndroid Build Coastguard Worker fn('threshold_', 0, 1), 937*da0073e9SAndroid Build Coastguard Worker out_function('threshold', 0, 1), 938*da0073e9SAndroid Build Coastguard Worker fn_method_and_inplace('trunc'), 939*da0073e9SAndroid Build Coastguard Worker method('uniform_'), 940*da0073e9SAndroid Build Coastguard Worker method('zero_'), 941*da0073e9SAndroid Build Coastguard Worker method('fill_', 1), 942*da0073e9SAndroid Build Coastguard Worker method('fill_', torch.tensor(3.14)), 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker # conversions 945*da0073e9SAndroid Build Coastguard Worker method('to', dtype=torch.long), 946*da0073e9SAndroid Build Coastguard Worker method('to', device='cpu'), 947*da0073e9SAndroid Build Coastguard Worker method('to', torch.empty([])), 948*da0073e9SAndroid Build Coastguard Worker method('bool'), 949*da0073e9SAndroid Build Coastguard Worker method('byte'), 950*da0073e9SAndroid Build Coastguard Worker method('char'), 951*da0073e9SAndroid Build Coastguard Worker method('cpu'), 952*da0073e9SAndroid Build Coastguard Worker method('double'), 953*da0073e9SAndroid Build Coastguard Worker method('float'), 954*da0073e9SAndroid Build Coastguard Worker method('long'), 955*da0073e9SAndroid Build Coastguard Worker method('half'), 956*da0073e9SAndroid Build Coastguard Worker method('int'), 957*da0073e9SAndroid Build Coastguard Worker method('short'), 958*da0073e9SAndroid Build Coastguard Worker method('type', dtype=torch.long), 959*da0073e9SAndroid Build Coastguard Worker 960*da0073e9SAndroid Build Coastguard Worker # cumsum and cumprod 961*da0073e9SAndroid Build Coastguard Worker fn('cumsum', 0), 962*da0073e9SAndroid Build Coastguard Worker fn('cumsum', 'D'), 963*da0073e9SAndroid Build Coastguard Worker out_function('cumsum', 'D'), 964*da0073e9SAndroid Build Coastguard Worker fn('cumprod', 0), 965*da0073e9SAndroid Build Coastguard Worker fn('cumprod', 'D'), 966*da0073e9SAndroid Build Coastguard Worker out_function('cumprod', 'D'), 967*da0073e9SAndroid Build Coastguard Worker 968*da0073e9SAndroid Build Coastguard Worker # views 969*da0073e9SAndroid Build Coastguard Worker method('narrow', 0, 0, 1), 970*da0073e9SAndroid Build Coastguard Worker 971*da0073e9SAndroid Build Coastguard Worker # creation functions 972*da0073e9SAndroid Build Coastguard Worker fn('empty_like'), 973*da0073e9SAndroid Build Coastguard Worker fn('zeros_like'), 974*da0073e9SAndroid Build Coastguard Worker fn('ones_like'), 975*da0073e9SAndroid Build Coastguard Worker fn('full_like', 3.14), 976*da0073e9SAndroid Build Coastguard Worker fn('rand_like'), 977*da0073e9SAndroid Build Coastguard Worker fn('randn_like'), 978*da0073e9SAndroid Build Coastguard Worker 979*da0073e9SAndroid Build Coastguard Worker # bernoulli variants 980*da0073e9SAndroid Build Coastguard Worker method('bernoulli_', 0.5), 981*da0073e9SAndroid Build Coastguard Worker method('bernoulli_', torch.tensor(0.5)), 982*da0073e9SAndroid Build Coastguard Worker 983*da0073e9SAndroid Build Coastguard Worker method('softmax', dim=1), 984*da0073e9SAndroid Build Coastguard Worker method('softmax', dim='D'), 985*da0073e9SAndroid Build Coastguard Worker method('log_softmax', dim=1), 986*da0073e9SAndroid Build Coastguard Worker method('log_softmax', dim='D'), 987*da0073e9SAndroid Build Coastguard Worker 988*da0073e9SAndroid Build Coastguard Worker [Function('F.dropout(inplace)', lambda t: F.dropout(t, p=0.5, inplace=True))], 989*da0073e9SAndroid Build Coastguard Worker [Function('F.dropout(outplace)', lambda t: F.dropout(t, p=0.5, inplace=False))], 990*da0073e9SAndroid Build Coastguard Worker ] 991*da0073e9SAndroid Build Coastguard Worker tests = flatten(tests) 992*da0073e9SAndroid Build Coastguard Worker 993*da0073e9SAndroid Build Coastguard Worker for testcase, device in itertools.product(tests, get_all_device_types()): 994*da0073e9SAndroid Build Coastguard Worker _test(testcase, device=device) 995*da0073e9SAndroid Build Coastguard Worker 996*da0073e9SAndroid Build Coastguard Worker def test_cummax_cummin(self): 997*da0073e9SAndroid Build Coastguard Worker def test_ops(op): 998*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 999*da0073e9SAndroid Build Coastguard Worker names = ('N', 'D') 1000*da0073e9SAndroid Build Coastguard Worker tensor = torch.rand(2, 3, names=names) 1001*da0073e9SAndroid Build Coastguard Worker result = op(tensor, 0) 1002*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result[0].names, names) 1003*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result[1].names, names) 1004*da0073e9SAndroid Build Coastguard Worker test_ops(torch.cummax) 1005*da0073e9SAndroid Build Coastguard Worker test_ops(torch.cummin) 1006*da0073e9SAndroid Build Coastguard Worker 1007*da0073e9SAndroid Build Coastguard Worker def test_logcumsumexp(self): 1008*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 1009*da0073e9SAndroid Build Coastguard Worker names = ('N', 'D') 1010*da0073e9SAndroid Build Coastguard Worker tensor = torch.rand(2, 3, names=names) 1011*da0073e9SAndroid Build Coastguard Worker result = torch.logcumsumexp(tensor, 'D') 1012*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.names, names) 1013*da0073e9SAndroid Build Coastguard Worker 1014*da0073e9SAndroid Build Coastguard Worker def test_bitwise_not(self): 1015*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 1016*da0073e9SAndroid Build Coastguard Worker names = ('N', 'D') 1017*da0073e9SAndroid Build Coastguard Worker tensor = torch.zeros(2, 3, names=names, dtype=torch.bool) 1018*da0073e9SAndroid Build Coastguard Worker result = torch.empty(0, dtype=torch.bool) 1019*da0073e9SAndroid Build Coastguard Worker 1020*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.bitwise_not().names, names) 1021*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.bitwise_not(tensor, out=result).names, names) 1022*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.bitwise_not_().names, names) 1023*da0073e9SAndroid Build Coastguard Worker 1024*da0073e9SAndroid Build Coastguard Worker def test_logical_not(self): 1025*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 1026*da0073e9SAndroid Build Coastguard Worker names = ('N', 'D') 1027*da0073e9SAndroid Build Coastguard Worker tensor = torch.zeros(2, 3, names=names, dtype=torch.bool) 1028*da0073e9SAndroid Build Coastguard Worker result = torch.empty(0, dtype=torch.bool) 1029*da0073e9SAndroid Build Coastguard Worker 1030*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.logical_not().names, names) 1031*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.logical_not(tensor, out=result).names, names) 1032*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.logical_not_().names, names) 1033*da0073e9SAndroid Build Coastguard Worker 1034*da0073e9SAndroid Build Coastguard Worker def test_bernoulli(self): 1035*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 1036*da0073e9SAndroid Build Coastguard Worker names = ('N', 'D') 1037*da0073e9SAndroid Build Coastguard Worker tensor = torch.rand(2, 3, names=names) 1038*da0073e9SAndroid Build Coastguard Worker result = torch.empty(0) 1039*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.bernoulli().names, names) 1040*da0073e9SAndroid Build Coastguard Worker 1041*da0073e9SAndroid Build Coastguard Worker torch.bernoulli(tensor, out=result) 1042*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.names, names) 1043*da0073e9SAndroid Build Coastguard Worker 1044*da0073e9SAndroid Build Coastguard Worker def test_flatten(self): 1045*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(2, 3, 5, 7, 11, names=('N', 'C', 'D', 'H', 'W')) 1046*da0073e9SAndroid Build Coastguard Worker 1047*da0073e9SAndroid Build Coastguard Worker # basic 1048*da0073e9SAndroid Build Coastguard Worker out = tensor.flatten('D', 'W', 'features') 1049*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.names, ['N', 'C', 'features']) 1050*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1)) 1051*da0073e9SAndroid Build Coastguard Worker 1052*da0073e9SAndroid Build Coastguard Worker # int overload 1053*da0073e9SAndroid Build Coastguard Worker out = tensor.flatten(2, 4, 'features') 1054*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.names, ['N', 'C', 'features']) 1055*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1)) 1056*da0073e9SAndroid Build Coastguard Worker 1057*da0073e9SAndroid Build Coastguard Worker # list overload 1058*da0073e9SAndroid Build Coastguard Worker out = tensor.flatten(['D', 'H', 'W'], 'features') 1059*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.names, ['N', 'C', 'features']) 1060*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1)) 1061*da0073e9SAndroid Build Coastguard Worker 1062*da0073e9SAndroid Build Coastguard Worker # Non-contiguous flatten: N and H are not "adjacent" in memory. 1063*da0073e9SAndroid Build Coastguard Worker sentences = torch.randn(2, 3, 5, 7, names=('N', 'T', 'H', 'D')) 1064*da0073e9SAndroid Build Coastguard Worker sentences = sentences.transpose('T', 'H') 1065*da0073e9SAndroid Build Coastguard Worker out = sentences.flatten('N', 'H', 'N_H') 1066*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.names, ['N_H', 'T', 'D']) 1067*da0073e9SAndroid Build Coastguard Worker 1068*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Name 'L' not found in"): 1069*da0073e9SAndroid Build Coastguard Worker tensor.flatten(['D', 'L'], 'features') 1070*da0073e9SAndroid Build Coastguard Worker 1071*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "must be consecutive in"): 1072*da0073e9SAndroid Build Coastguard Worker tensor.flatten(['D', 'W'], 'features') 1073*da0073e9SAndroid Build Coastguard Worker 1074*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "must be consecutive in"): 1075*da0073e9SAndroid Build Coastguard Worker tensor.flatten(['H', 'D', 'W'], 'features') 1076*da0073e9SAndroid Build Coastguard Worker 1077*da0073e9SAndroid Build Coastguard Worker def test_flatten_nodims(self): 1078*da0073e9SAndroid Build Coastguard Worker tensor = torch.empty((2, 3)) 1079*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "cannot be empty"): 1080*da0073e9SAndroid Build Coastguard Worker tensor.flatten((), 'abcd') 1081*da0073e9SAndroid Build Coastguard Worker 1082*da0073e9SAndroid Build Coastguard Worker def test_flatten_index_error(self): 1083*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(1, 2) 1084*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, 1085*da0073e9SAndroid Build Coastguard Worker r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"): 1086*da0073e9SAndroid Build Coastguard Worker tensor.flatten(0, 2) 1087*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, 1088*da0073e9SAndroid Build Coastguard Worker r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"): 1089*da0073e9SAndroid Build Coastguard Worker tensor.flatten(0, 2, 'N') 1090*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 1091*da0073e9SAndroid Build Coastguard Worker r"flatten\(\) has invalid args: start_dim cannot come after end_dim"): 1092*da0073e9SAndroid Build Coastguard Worker tensor.flatten(1, 0) 1093*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 1094*da0073e9SAndroid Build Coastguard Worker r"flatten\(\) has invalid args: start_dim cannot come after end_dim"): 1095*da0073e9SAndroid Build Coastguard Worker tensor.flatten(1, 0, 'N') 1096*da0073e9SAndroid Build Coastguard Worker 1097*da0073e9SAndroid Build Coastguard Worker def test_unflatten(self): 1098*da0073e9SAndroid Build Coastguard Worker # test args: tensor, int, namedshape 1099*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal( 1100*da0073e9SAndroid Build Coastguard Worker torch.ones(4, names=('A',)).unflatten('A', (('A', 2), ('B', 2))), 1101*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 2, names=('A', 'B')))) 1102*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal( 1103*da0073e9SAndroid Build Coastguard Worker torch.ones(4, names=('A',)).unflatten('A', [('A', 2), ('B', 2)]), 1104*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 2, names=('A', 'B')))) 1105*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal( 1106*da0073e9SAndroid Build Coastguard Worker torch.ones(4, names=('A',)).unflatten('A', (['A', 2], ['B', 2])), 1107*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 2, names=('A', 'B')))) 1108*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal( 1109*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 10, names=('A', 'B')).unflatten('B', (['B1', -1],)), 1110*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 10, names=('A', 'B1')))) 1111*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal( 1112*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 3 * 4 * 5 * 6, names=('A', 'B')) 1113*da0073e9SAndroid Build Coastguard Worker .unflatten('B', (['B1', 3], ['B2', 4], ['B3', -1], ['B4', 6])), 1114*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 3, 4, 5, 6, names=('A', 'B1', 'B2', 'B3', 'B4')))) 1115*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal( 1116*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 0, names=('A', 'B')) 1117*da0073e9SAndroid Build Coastguard Worker .unflatten('B', (['B1', 3], ['B2', -1], ['B3', 4])), 1118*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 3, 0, 4, names=('A', 'B1', 'B2', 'B3')))) 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Worker # test args: namedtensor, str, namedshape 1121*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal( 1122*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))), 1123*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 2, 2, names=('A', 'B1', 'B2')))) 1124*da0073e9SAndroid Build Coastguard Worker 1125*da0073e9SAndroid Build Coastguard Worker # test invalid args: namedtensor, str, sizes 1126*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"): 1127*da0073e9SAndroid Build Coastguard Worker torch.tensor([1], names=('A',)).unflatten('A', (1, 1)) 1128*da0073e9SAndroid Build Coastguard Worker 1129*da0073e9SAndroid Build Coastguard Worker # test invalid args: namedtensor, int, sizes 1130*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"input is a named tensor but no names were given for unflattened sizes"): 1131*da0073e9SAndroid Build Coastguard Worker torch.tensor([1], names=("A",)).unflatten(0, (1, 1)) 1132*da0073e9SAndroid Build Coastguard Worker 1133*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 1134*da0073e9SAndroid Build Coastguard Worker r"Provided sizes \[3, -1\] don't multiply up to the " 1135*da0073e9SAndroid Build Coastguard Worker r"size of dim 1 \('B': 4\) in Tensor\['A', 'B'\]"): 1136*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 3), ('B2', -1))) 1137*da0073e9SAndroid Build Coastguard Worker 1138*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 1139*da0073e9SAndroid Build Coastguard Worker r"the unspecified dimension size -1 can be any value and is ambiguous"): 1140*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 0, names=('A', 'B')).unflatten('B', (('B1', 0), ('B2', -1))) 1141*da0073e9SAndroid Build Coastguard Worker 1142*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(7, 2 * 3 * 5, 11, names=('N', 'D', 'K')) 1143*da0073e9SAndroid Build Coastguard Worker 1144*da0073e9SAndroid Build Coastguard Worker # accepts OrderedDict 1145*da0073e9SAndroid Build Coastguard Worker out = tensor.unflatten('D', OrderedDict((('C', 2), ('H', 3), ('W', 5)))) 1146*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K')) 1147*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.shape, (7, 2, 3, 5, 11)) 1148*da0073e9SAndroid Build Coastguard Worker 1149*da0073e9SAndroid Build Coastguard Worker # Unflatten left-most 1150*da0073e9SAndroid Build Coastguard Worker out = tensor.unflatten('N', (('N', 7), ('H', 1))) 1151*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.names, ('N', 'H', 'D', 'K')) 1152*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.shape, (7, 1, 2 * 3 * 5, 11)) 1153*da0073e9SAndroid Build Coastguard Worker 1154*da0073e9SAndroid Build Coastguard Worker # Unflatten right-most 1155*da0073e9SAndroid Build Coastguard Worker out = tensor.unflatten('K', (('K', 11), ('H', 1))) 1156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.names, ('N', 'D', 'K', 'H')) 1157*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.shape, (7, 2 * 3 * 5, 11, 1)) 1158*da0073e9SAndroid Build Coastguard Worker 1159*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "don't multiply up to"): 1160*da0073e9SAndroid Build Coastguard Worker tensor.unflatten('D', (('H', 3), ('W', 5))) 1161*da0073e9SAndroid Build Coastguard Worker 1162*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'sizes must be non-empty'): 1163*da0073e9SAndroid Build Coastguard Worker tensor.unflatten('D', None) 1164*da0073e9SAndroid Build Coastguard Worker 1165*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'non-empty'): 1166*da0073e9SAndroid Build Coastguard Worker tensor.unflatten('D', OrderedDict()) 1167*da0073e9SAndroid Build Coastguard Worker 1168*da0073e9SAndroid Build Coastguard Worker def test_unsupported_op_error_msg(self): 1169*da0073e9SAndroid Build Coastguard Worker named = torch.randn(3, 3, names=('N', 'C')) 1170*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1171*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"pdist.+is not yet supported with named tensors"): 1172*da0073e9SAndroid Build Coastguard Worker torch.pdist(named) 1173*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1174*da0073e9SAndroid Build Coastguard Worker RuntimeError, r"as_strided_.+is not yet supported with named tensors"): 1175*da0073e9SAndroid Build Coastguard Worker named.as_strided_((3, 3), (3, 1)) 1176*da0073e9SAndroid Build Coastguard Worker 1177*da0073e9SAndroid Build Coastguard Worker def test_reduction_fns(self): 1178*da0073e9SAndroid Build Coastguard Worker def check_output(output, expected_names): 1179*da0073e9SAndroid Build Coastguard Worker if isinstance(output, torch.Tensor): 1180*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, expected_names) 1181*da0073e9SAndroid Build Coastguard Worker return 1182*da0073e9SAndroid Build Coastguard Worker for out in output: 1183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.names, expected_names) 1184*da0073e9SAndroid Build Coastguard Worker 1185*da0073e9SAndroid Build Coastguard Worker def sum_all_outputs(output): 1186*da0073e9SAndroid Build Coastguard Worker if isinstance(output, torch.Tensor): 1187*da0073e9SAndroid Build Coastguard Worker return output.sum() 1188*da0073e9SAndroid Build Coastguard Worker result = 0 1189*da0073e9SAndroid Build Coastguard Worker for out in output: 1190*da0073e9SAndroid Build Coastguard Worker result = out + result 1191*da0073e9SAndroid Build Coastguard Worker return result.sum() 1192*da0073e9SAndroid Build Coastguard Worker 1193*da0073e9SAndroid Build Coastguard Worker def test_simple_reduce(op, device): 1194*da0073e9SAndroid Build Coastguard Worker t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) 1195*da0073e9SAndroid Build Coastguard Worker check_output(op(t, 1), ['N', 'L']) 1196*da0073e9SAndroid Build Coastguard Worker check_output(op(t, -1), ['N', 'C']) 1197*da0073e9SAndroid Build Coastguard Worker check_output(op(t, 'C'), ['N', 'L']) 1198*da0073e9SAndroid Build Coastguard Worker ops_support_dim_none = [ 1199*da0073e9SAndroid Build Coastguard Worker 'sum', 1200*da0073e9SAndroid Build Coastguard Worker 'mean', 1201*da0073e9SAndroid Build Coastguard Worker 'std', 1202*da0073e9SAndroid Build Coastguard Worker 'var', 1203*da0073e9SAndroid Build Coastguard Worker 'std_mean', 1204*da0073e9SAndroid Build Coastguard Worker 'var_mean', 1205*da0073e9SAndroid Build Coastguard Worker 'nanmean', 1206*da0073e9SAndroid Build Coastguard Worker 'nansum', 1207*da0073e9SAndroid Build Coastguard Worker ] 1208*da0073e9SAndroid Build Coastguard Worker if op.__name__ in ops_support_dim_none: 1209*da0073e9SAndroid Build Coastguard Worker check_output(op(t, None), []) 1210*da0073e9SAndroid Build Coastguard Worker else: 1211*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'): 1212*da0073e9SAndroid Build Coastguard Worker op(t, None) 1213*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'): 1214*da0073e9SAndroid Build Coastguard Worker op(t, 'H') 1215*da0073e9SAndroid Build Coastguard Worker 1216*da0073e9SAndroid Build Coastguard Worker def test_autograd_supports_dimname_overload(op, device): 1217*da0073e9SAndroid Build Coastguard Worker t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device, requires_grad=True) 1218*da0073e9SAndroid Build Coastguard Worker sum_all_outputs(op(t, 'C')).backward() 1219*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(t.grad) 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker def test_complete_reduce(op, device): 1222*da0073e9SAndroid Build Coastguard Worker t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) 1223*da0073e9SAndroid Build Coastguard Worker check_output(op(t), []) 1224*da0073e9SAndroid Build Coastguard Worker 1225*da0073e9SAndroid Build Coastguard Worker def test_multidim_reduce(op, device): 1226*da0073e9SAndroid Build Coastguard Worker t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) 1227*da0073e9SAndroid Build Coastguard Worker 1228*da0073e9SAndroid Build Coastguard Worker check_output(op(t, [1, 2]), ['N']) 1229*da0073e9SAndroid Build Coastguard Worker check_output(op(t, [0, -1]), ['C']) 1230*da0073e9SAndroid Build Coastguard Worker check_output(op(t, ['C', 'L']), ['N']) 1231*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'): 1232*da0073e9SAndroid Build Coastguard Worker op(t, [None, 'C']) 1233*da0073e9SAndroid Build Coastguard Worker 1234*da0073e9SAndroid Build Coastguard Worker def test_out_variant(op, output_lambda, device): 1235*da0073e9SAndroid Build Coastguard Worker t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) 1236*da0073e9SAndroid Build Coastguard Worker if output_lambda: 1237*da0073e9SAndroid Build Coastguard Worker out = output_lambda(t) 1238*da0073e9SAndroid Build Coastguard Worker else: 1239*da0073e9SAndroid Build Coastguard Worker out = torch.empty([0], device=device) 1240*da0073e9SAndroid Build Coastguard Worker op(t, 'C', out=out) 1241*da0073e9SAndroid Build Coastguard Worker check_output(out, ['N', 'L']) 1242*da0073e9SAndroid Build Coastguard Worker 1243*da0073e9SAndroid Build Coastguard Worker def test_keepdim(op, device): 1244*da0073e9SAndroid Build Coastguard Worker t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) 1245*da0073e9SAndroid Build Coastguard Worker check_output(op(t, 'C', keepdim=True), ['N', 'C', 'L']) 1246*da0073e9SAndroid Build Coastguard Worker 1247*da0073e9SAndroid Build Coastguard Worker def values_and_indices(t): 1248*da0073e9SAndroid Build Coastguard Worker return (torch.empty([0], device=t.device), 1249*da0073e9SAndroid Build Coastguard Worker torch.empty([0], device=t.device, dtype=torch.long)) 1250*da0073e9SAndroid Build Coastguard Worker 1251*da0073e9SAndroid Build Coastguard Worker def kthvalue_wrapper(tensor, *args, **kwargs): 1252*da0073e9SAndroid Build Coastguard Worker # Return the 0-th value 1253*da0073e9SAndroid Build Coastguard Worker return torch.kthvalue(tensor, 1, *args, **kwargs) 1254*da0073e9SAndroid Build Coastguard Worker 1255*da0073e9SAndroid Build Coastguard Worker Case = namedtuple('Case', [ 1256*da0073e9SAndroid Build Coastguard Worker 'op', 1257*da0073e9SAndroid Build Coastguard Worker 'supports_complete_reduce', 1258*da0073e9SAndroid Build Coastguard Worker 'supports_multidim_reduce', 1259*da0073e9SAndroid Build Coastguard Worker 'supports_out_variant', 1260*da0073e9SAndroid Build Coastguard Worker 'supports_keepdim', 1261*da0073e9SAndroid Build Coastguard Worker 'output_lambda', 1262*da0073e9SAndroid Build Coastguard Worker ]) 1263*da0073e9SAndroid Build Coastguard Worker 1264*da0073e9SAndroid Build Coastguard Worker tests = [ 1265*da0073e9SAndroid Build Coastguard Worker Case(torch.sum, True, True, True, True, None), 1266*da0073e9SAndroid Build Coastguard Worker Case(torch.prod, True, False, True, True, None), 1267*da0073e9SAndroid Build Coastguard Worker Case(torch.mean, True, True, True, True, None), 1268*da0073e9SAndroid Build Coastguard Worker Case(torch.var, True, True, True, True, None), 1269*da0073e9SAndroid Build Coastguard Worker Case(torch.std, True, True, True, True, None), 1270*da0073e9SAndroid Build Coastguard Worker Case(torch.std_mean, True, True, False, True, None), 1271*da0073e9SAndroid Build Coastguard Worker Case(torch.var_mean, True, True, False, True, None), 1272*da0073e9SAndroid Build Coastguard Worker Case(torch.min, True, False, True, True, values_and_indices), 1273*da0073e9SAndroid Build Coastguard Worker Case(torch.max, True, False, True, True, values_and_indices), 1274*da0073e9SAndroid Build Coastguard Worker Case(torch.unbind, False, False, False, False, None), 1275*da0073e9SAndroid Build Coastguard Worker Case(torch.logsumexp, False, True, True, True, None), 1276*da0073e9SAndroid Build Coastguard Worker Case(torch.mode, False, False, True, True, values_and_indices), 1277*da0073e9SAndroid Build Coastguard Worker Case(kthvalue_wrapper, False, False, True, True, values_and_indices), 1278*da0073e9SAndroid Build Coastguard Worker Case(torch.median, True, False, True, True, values_and_indices), 1279*da0073e9SAndroid Build Coastguard Worker Case(torch.nanmedian, True, False, True, True, values_and_indices), 1280*da0073e9SAndroid Build Coastguard Worker ] 1281*da0073e9SAndroid Build Coastguard Worker 1282*da0073e9SAndroid Build Coastguard Worker for testcase, device in itertools.product(tests, get_all_device_types()): 1283*da0073e9SAndroid Build Coastguard Worker op = testcase.op 1284*da0073e9SAndroid Build Coastguard Worker test_simple_reduce(op, device) 1285*da0073e9SAndroid Build Coastguard Worker test_autograd_supports_dimname_overload(op, device) 1286*da0073e9SAndroid Build Coastguard Worker 1287*da0073e9SAndroid Build Coastguard Worker if testcase.supports_keepdim: 1288*da0073e9SAndroid Build Coastguard Worker test_keepdim(op, device) 1289*da0073e9SAndroid Build Coastguard Worker if testcase.supports_out_variant: 1290*da0073e9SAndroid Build Coastguard Worker test_out_variant(op, testcase.output_lambda, device) 1291*da0073e9SAndroid Build Coastguard Worker if testcase.supports_complete_reduce: 1292*da0073e9SAndroid Build Coastguard Worker test_complete_reduce(op, device) 1293*da0073e9SAndroid Build Coastguard Worker if testcase.supports_multidim_reduce: 1294*da0073e9SAndroid Build Coastguard Worker test_multidim_reduce(op, device) 1295*da0073e9SAndroid Build Coastguard Worker 1296*da0073e9SAndroid Build Coastguard Worker def test_masked_select(self): 1297*da0073e9SAndroid Build Coastguard Worker # simple 1298*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1299*da0073e9SAndroid Build Coastguard Worker torch.masked_select, 1300*da0073e9SAndroid Build Coastguard Worker (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')), 1301*da0073e9SAndroid Build Coastguard Worker expected_names=[None]) 1302*da0073e9SAndroid Build Coastguard Worker 1303*da0073e9SAndroid Build Coastguard Worker # left broadcast 1304*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1305*da0073e9SAndroid Build Coastguard Worker torch.masked_select, 1306*da0073e9SAndroid Build Coastguard Worker (create('C:3'), (create('2,3') > 0).rename('N', 'C')), 1307*da0073e9SAndroid Build Coastguard Worker expected_names=[None]) 1308*da0073e9SAndroid Build Coastguard Worker 1309*da0073e9SAndroid Build Coastguard Worker # right broadcast 1310*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1311*da0073e9SAndroid Build Coastguard Worker torch.masked_select, 1312*da0073e9SAndroid Build Coastguard Worker (create('N:2,C:3'), (create('3') > 0).rename('C')), 1313*da0073e9SAndroid Build Coastguard Worker expected_names=[None]) 1314*da0073e9SAndroid Build Coastguard Worker 1315*da0073e9SAndroid Build Coastguard Worker # error 1316*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1317*da0073e9SAndroid Build Coastguard Worker torch.masked_select, 1318*da0073e9SAndroid Build Coastguard Worker (create('N:2,C:3'), (create('3') > 0).rename('D')), 1319*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='do not match') 1320*da0073e9SAndroid Build Coastguard Worker 1321*da0073e9SAndroid Build Coastguard Worker # out= 1322*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1323*da0073e9SAndroid Build Coastguard Worker out_fn(torch.masked_select), 1324*da0073e9SAndroid Build Coastguard Worker (create('0'), create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')), 1325*da0073e9SAndroid Build Coastguard Worker expected_names=[None]) 1326*da0073e9SAndroid Build Coastguard Worker 1327*da0073e9SAndroid Build Coastguard Worker def test_cat(self): 1328*da0073e9SAndroid Build Coastguard Worker # simple 1329*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1330*da0073e9SAndroid Build Coastguard Worker torch.cat, 1331*da0073e9SAndroid Build Coastguard Worker [[create('N:2,C:3'), create('N:2,C:3')]], 1332*da0073e9SAndroid Build Coastguard Worker expected_names=['N', 'C']) 1333*da0073e9SAndroid Build Coastguard Worker 1334*da0073e9SAndroid Build Coastguard Worker # error: zero dim 1335*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1336*da0073e9SAndroid Build Coastguard Worker torch.cat, 1337*da0073e9SAndroid Build Coastguard Worker [[create(''), create('')]], 1338*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='zero-dim') 1339*da0073e9SAndroid Build Coastguard Worker 1340*da0073e9SAndroid Build Coastguard Worker # error: names don't match 1341*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1342*da0073e9SAndroid Build Coastguard Worker torch.cat, 1343*da0073e9SAndroid Build Coastguard Worker [[create('N:2,C:3'), create('C:3,N:2')]], 1344*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='do not match') 1345*da0073e9SAndroid Build Coastguard Worker 1346*da0073e9SAndroid Build Coastguard Worker # error: different number of dims 1347*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1348*da0073e9SAndroid Build Coastguard Worker torch.cat, 1349*da0073e9SAndroid Build Coastguard Worker [[create('N:2,C:3'), create('C:3')]], 1350*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='must have same number of dimensions') 1351*da0073e9SAndroid Build Coastguard Worker 1352*da0073e9SAndroid Build Coastguard Worker # out= 1353*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1354*da0073e9SAndroid Build Coastguard Worker out_fn(torch.cat), 1355*da0073e9SAndroid Build Coastguard Worker [create('0'), [create('N:2,C:3'), create('N:2,C:3')]], 1356*da0073e9SAndroid Build Coastguard Worker expected_names=['N', 'C']) 1357*da0073e9SAndroid Build Coastguard Worker 1358*da0073e9SAndroid Build Coastguard Worker def test_masked_fill(self): 1359*da0073e9SAndroid Build Coastguard Worker # simple 1360*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1361*da0073e9SAndroid Build Coastguard Worker Tensor.masked_fill, 1362*da0073e9SAndroid Build Coastguard Worker (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), 1363*da0073e9SAndroid Build Coastguard Worker expected_names=['N', 'C']) 1364*da0073e9SAndroid Build Coastguard Worker 1365*da0073e9SAndroid Build Coastguard Worker # left broadcast 1366*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1367*da0073e9SAndroid Build Coastguard Worker Tensor.masked_fill, 1368*da0073e9SAndroid Build Coastguard Worker (create('C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), 1369*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex="must be less than or equal to") 1370*da0073e9SAndroid Build Coastguard Worker 1371*da0073e9SAndroid Build Coastguard Worker # right broadcast 1372*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1373*da0073e9SAndroid Build Coastguard Worker Tensor.masked_fill, 1374*da0073e9SAndroid Build Coastguard Worker (create('N:2,C:3'), (create('3') > 0).rename('C'), 3.14), 1375*da0073e9SAndroid Build Coastguard Worker expected_names=['N', 'C']) 1376*da0073e9SAndroid Build Coastguard Worker 1377*da0073e9SAndroid Build Coastguard Worker # error 1378*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1379*da0073e9SAndroid Build Coastguard Worker Tensor.masked_fill, 1380*da0073e9SAndroid Build Coastguard Worker (create('N:2,C:3'), (create('3') > 0).rename('D'), 3.14), 1381*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='do not match') 1382*da0073e9SAndroid Build Coastguard Worker 1383*da0073e9SAndroid Build Coastguard Worker # inplace 1384*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1385*da0073e9SAndroid Build Coastguard Worker Tensor.masked_fill_, 1386*da0073e9SAndroid Build Coastguard Worker (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), 1387*da0073e9SAndroid Build Coastguard Worker expected_names=['N', 'C']) 1388*da0073e9SAndroid Build Coastguard Worker 1389*da0073e9SAndroid Build Coastguard Worker # inplace, computed names don't match output tensor names 1390*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1391*da0073e9SAndroid Build Coastguard Worker Tensor.masked_fill_, 1392*da0073e9SAndroid Build Coastguard Worker (create('N:2,None:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), 1393*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex="not the same as the computed output names") 1394*da0073e9SAndroid Build Coastguard Worker 1395*da0073e9SAndroid Build Coastguard Worker 1396*da0073e9SAndroid Build Coastguard Worker def test_using_seen_interned_string_doesnt_bump_refcount(self): 1397*da0073e9SAndroid Build Coastguard Worker def see_name(): 1398*da0073e9SAndroid Build Coastguard Worker seen_name = 'N' 1399*da0073e9SAndroid Build Coastguard Worker pass_name_to_python_arg_parser(seen_name) 1400*da0073e9SAndroid Build Coastguard Worker 1401*da0073e9SAndroid Build Coastguard Worker see_name() 1402*da0073e9SAndroid Build Coastguard Worker seen_name = 'N' 1403*da0073e9SAndroid Build Coastguard Worker old_refcnt = sys.getrefcount(seen_name) 1404*da0073e9SAndroid Build Coastguard Worker 1405*da0073e9SAndroid Build Coastguard Worker pass_name_to_python_arg_parser(seen_name) 1406*da0073e9SAndroid Build Coastguard Worker 1407*da0073e9SAndroid Build Coastguard Worker new_refcnt = sys.getrefcount(seen_name) 1408*da0073e9SAndroid Build Coastguard Worker self.assertEqual(new_refcnt, old_refcnt) 1409*da0073e9SAndroid Build Coastguard Worker 1410*da0073e9SAndroid Build Coastguard Worker # This test is failing on Python 3.12: https://github.com/pytorch/pytorch/issues/119464 1411*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(sys.version_info >= (3, 12), "Failing on python 3.12+") 1412*da0073e9SAndroid Build Coastguard Worker def test_using_unseen_interned_string_bumps_refcount_permanently(self): 1413*da0073e9SAndroid Build Coastguard Worker # Please don't use this as a name in a different test. 1414*da0073e9SAndroid Build Coastguard Worker unseen_name = 'abcdefghi' 1415*da0073e9SAndroid Build Coastguard Worker old_refcnt = sys.getrefcount(unseen_name) 1416*da0073e9SAndroid Build Coastguard Worker 1417*da0073e9SAndroid Build Coastguard Worker pass_name_to_python_arg_parser(unseen_name) 1418*da0073e9SAndroid Build Coastguard Worker 1419*da0073e9SAndroid Build Coastguard Worker new_refcnt = sys.getrefcount(unseen_name) 1420*da0073e9SAndroid Build Coastguard Worker self.assertEqual(new_refcnt, old_refcnt + 1) 1421*da0073e9SAndroid Build Coastguard Worker 1422*da0073e9SAndroid Build Coastguard Worker # This test is failing on Python 3.12: https://github.com/pytorch/pytorch/issues/119464 1423*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(sys.version_info >= (3, 12), "Failing on python 3.12+") 1424*da0073e9SAndroid Build Coastguard Worker def test_using_unseen_uninterned_string_refcounts(self): 1425*da0073e9SAndroid Build Coastguard Worker # Please don't use this as a name in a different test. 1426*da0073e9SAndroid Build Coastguard Worker # non-compile-time constants are not interned 1427*da0073e9SAndroid Build Coastguard Worker unseen_name = ''.join(['abc', 'def', 'ghi', 'jkl']) 1428*da0073e9SAndroid Build Coastguard Worker interned_unseen_name = 'abcdefghijkl' 1429*da0073e9SAndroid Build Coastguard Worker self.assertFalse(unseen_name is interned_unseen_name) 1430*da0073e9SAndroid Build Coastguard Worker 1431*da0073e9SAndroid Build Coastguard Worker old_uninterned_refcnt = sys.getrefcount(unseen_name) 1432*da0073e9SAndroid Build Coastguard Worker old_interned_refcnt = sys.getrefcount(interned_unseen_name) 1433*da0073e9SAndroid Build Coastguard Worker 1434*da0073e9SAndroid Build Coastguard Worker pass_name_to_python_arg_parser(unseen_name) 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker new_uninterned_refcnt = sys.getrefcount(unseen_name) 1437*da0073e9SAndroid Build Coastguard Worker new_interned_refcnt = sys.getrefcount(interned_unseen_name) 1438*da0073e9SAndroid Build Coastguard Worker 1439*da0073e9SAndroid Build Coastguard Worker # Internally, PyTorch should not hold a reference to the uninterned string 1440*da0073e9SAndroid Build Coastguard Worker self.assertEqual(new_uninterned_refcnt, old_uninterned_refcnt) 1441*da0073e9SAndroid Build Coastguard Worker 1442*da0073e9SAndroid Build Coastguard Worker # Instead, we should hold a new reference to the interned version. 1443*da0073e9SAndroid Build Coastguard Worker self.assertEqual(new_interned_refcnt, old_interned_refcnt + 1) 1444*da0073e9SAndroid Build Coastguard Worker 1445*da0073e9SAndroid Build Coastguard Worker def _test_select(self, device): 1446*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device) 1447*da0073e9SAndroid Build Coastguard Worker y = x.select(1, 1) 1448*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.names, ('N', 'H', 'W')) 1449*da0073e9SAndroid Build Coastguard Worker 1450*da0073e9SAndroid Build Coastguard Worker y = x.select('C', 1) 1451*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.names, ('N', 'H', 'W')) 1452*da0073e9SAndroid Build Coastguard Worker 1453*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1454*da0073e9SAndroid Build Coastguard Worker RuntimeError, 'Please look up dimensions by name'): 1455*da0073e9SAndroid Build Coastguard Worker y = x.select(None, 1) 1456*da0073e9SAndroid Build Coastguard Worker 1457*da0073e9SAndroid Build Coastguard Worker def test_select(self): 1458*da0073e9SAndroid Build Coastguard Worker self._test_select('cpu') 1459*da0073e9SAndroid Build Coastguard Worker 1460*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, 'no CUDA') 1461*da0073e9SAndroid Build Coastguard Worker def test_select_cuda(self): 1462*da0073e9SAndroid Build Coastguard Worker self._test_select('cuda') 1463*da0073e9SAndroid Build Coastguard Worker 1464*da0073e9SAndroid Build Coastguard Worker def _test_as_strided(self, device): 1465*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device) 1466*da0073e9SAndroid Build Coastguard Worker y = x.as_strided([2 * 3 * 4 * 5], [1]) 1467*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.names, (None,)) 1468*da0073e9SAndroid Build Coastguard Worker 1469*da0073e9SAndroid Build Coastguard Worker def test_as_strided(self): 1470*da0073e9SAndroid Build Coastguard Worker self._test_as_strided('cpu') 1471*da0073e9SAndroid Build Coastguard Worker 1472*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_CUDA, 'no CUDA') 1473*da0073e9SAndroid Build Coastguard Worker def test_as_strided_cuda(self): 1474*da0073e9SAndroid Build Coastguard Worker self._test_as_strided('cuda') 1475*da0073e9SAndroid Build Coastguard Worker 1476*da0073e9SAndroid Build Coastguard Worker def test_no_jit_tracer_support(self): 1477*da0073e9SAndroid Build Coastguard Worker def foo(x): 1478*da0073e9SAndroid Build Coastguard Worker return torch.full(x.shape, 2., names=('N',)) 1479*da0073e9SAndroid Build Coastguard Worker 1480*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'): 1481*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1482*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(foo, example_inputs=x) 1483*da0073e9SAndroid Build Coastguard Worker 1484*da0073e9SAndroid Build Coastguard Worker def bar(x): 1485*da0073e9SAndroid Build Coastguard Worker return x.select('N', 1) 1486*da0073e9SAndroid Build Coastguard Worker 1487*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'): 1488*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1489*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(bar, example_inputs=x) 1490*da0073e9SAndroid Build Coastguard Worker 1491*da0073e9SAndroid Build Coastguard Worker def test_no_jit_script_support(self): 1492*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1493*da0073e9SAndroid Build Coastguard Worker def foo(x): 1494*da0073e9SAndroid Build Coastguard Worker return x + 1 1495*da0073e9SAndroid Build Coastguard Worker 1496*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'NYI'): 1497*da0073e9SAndroid Build Coastguard Worker foo(torch.randn(2, 3, names=('N', 'C'))) 1498*da0073e9SAndroid Build Coastguard Worker 1499*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 1500*da0073e9SAndroid Build Coastguard Worker def add_names(x): 1501*da0073e9SAndroid Build Coastguard Worker x.names = ('N', 'C') 1502*da0073e9SAndroid Build Coastguard Worker 1503*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1504*da0073e9SAndroid Build Coastguard Worker def return_named_tensor(input): 1505*da0073e9SAndroid Build Coastguard Worker add_names(input) 1506*da0073e9SAndroid Build Coastguard Worker return input 1507*da0073e9SAndroid Build Coastguard Worker 1508*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "NYI"): 1509*da0073e9SAndroid Build Coastguard Worker return_named_tensor(torch.randn(1, 1)) 1510*da0073e9SAndroid Build Coastguard Worker 1511*da0073e9SAndroid Build Coastguard Worker def test_align_to(self): 1512*da0073e9SAndroid Build Coastguard Worker # trivial 1513*da0073e9SAndroid Build Coastguard Worker tensor = create('N:3') 1514*da0073e9SAndroid Build Coastguard Worker output = tensor.align_to('N') 1515*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['N']) 1516*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, [3]) 1517*da0073e9SAndroid Build Coastguard Worker 1518*da0073e9SAndroid Build Coastguard Worker # unsqueeze behavior 1519*da0073e9SAndroid Build Coastguard Worker tensor = create('N:3') 1520*da0073e9SAndroid Build Coastguard Worker output = tensor.align_to('N', 'D') 1521*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['N', 'D']) 1522*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, [3, 1]) 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker # transpose behavior 1525*da0073e9SAndroid Build Coastguard Worker tensor = create('N:3,C:2') 1526*da0073e9SAndroid Build Coastguard Worker output = tensor.align_to('C', 'N') 1527*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['C', 'N']) 1528*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, [2, 3]) 1529*da0073e9SAndroid Build Coastguard Worker 1530*da0073e9SAndroid Build Coastguard Worker # unsqueeze / transpose 1531*da0073e9SAndroid Build Coastguard Worker tensor = create('C:2,N:3,H:5') 1532*da0073e9SAndroid Build Coastguard Worker output = tensor.align_to('N', 'H', 'W', 'C') 1533*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['N', 'H', 'W', 'C']) 1534*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, [3, 5, 1, 2]) 1535*da0073e9SAndroid Build Coastguard Worker 1536*da0073e9SAndroid Build Coastguard Worker # All input dimensions must be named 1537*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "All input dims must be named. Found unnamed dim at index 0"): 1538*da0073e9SAndroid Build Coastguard Worker create('None:2,C:3').align_to('N', 'C') 1539*da0073e9SAndroid Build Coastguard Worker 1540*da0073e9SAndroid Build Coastguard Worker # not enough names 1541*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'N'"): 1542*da0073e9SAndroid Build Coastguard Worker create('N:2,C:3').align_to('C') 1543*da0073e9SAndroid Build Coastguard Worker 1544*da0073e9SAndroid Build Coastguard Worker # names not found 1545*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'C'"): 1546*da0073e9SAndroid Build Coastguard Worker create('N:2,C:3').align_to('D', 'N') 1547*da0073e9SAndroid Build Coastguard Worker 1548*da0073e9SAndroid Build Coastguard Worker def test_align_to_ellipsis(self): 1549*da0073e9SAndroid Build Coastguard Worker tensor = create('N:7,H:3,W:5,C:2') 1550*da0073e9SAndroid Build Coastguard Worker 1551*da0073e9SAndroid Build Coastguard Worker # ... = ['N', 'H', 'W', 'C'] 1552*da0073e9SAndroid Build Coastguard Worker output = tensor.align_to('...') 1553*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['N', 'H', 'W', 'C']) 1554*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, [7, 3, 5, 2]) 1555*da0073e9SAndroid Build Coastguard Worker 1556*da0073e9SAndroid Build Coastguard Worker # ... = ['H', 'C'] 1557*da0073e9SAndroid Build Coastguard Worker output = tensor.align_to('...', 'W', 'N') 1558*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['H', 'C', 'W', 'N']) 1559*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, [3, 2, 5, 7]) 1560*da0073e9SAndroid Build Coastguard Worker 1561*da0073e9SAndroid Build Coastguard Worker # ... = ['N', 'W'] 1562*da0073e9SAndroid Build Coastguard Worker output = tensor.align_to('H', 'C', '...') 1563*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['H', 'C', 'N', 'W']) 1564*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, [3, 2, 7, 5]) 1565*da0073e9SAndroid Build Coastguard Worker 1566*da0073e9SAndroid Build Coastguard Worker # ... = ['H', 'C'] 1567*da0073e9SAndroid Build Coastguard Worker output = tensor.align_to('W', '...', 'N') 1568*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['W', 'H', 'C', 'N']) 1569*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, [5, 3, 2, 7]) 1570*da0073e9SAndroid Build Coastguard Worker 1571*da0073e9SAndroid Build Coastguard Worker # ... = [] 1572*da0073e9SAndroid Build Coastguard Worker output = tensor.align_to('N', '...', 'C', 'D', 'H', 'W') 1573*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['N', 'C', 'D', 'H', 'W']) 1574*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, [7, 2, 1, 3, 5]) 1575*da0073e9SAndroid Build Coastguard Worker 1576*da0073e9SAndroid Build Coastguard Worker # Input tensor partially named 1577*da0073e9SAndroid Build Coastguard Worker partially_named = create('None:2,None:3,None:5,C:7') 1578*da0073e9SAndroid Build Coastguard Worker output = partially_named.align_to('C', '...') 1579*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['C', None, None, None]) 1580*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, [7, 2, 3, 5]) 1581*da0073e9SAndroid Build Coastguard Worker 1582*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "order of dimensions cannot contain a None"): 1583*da0073e9SAndroid Build Coastguard Worker partially_named.align_to('C', None, '...') 1584*da0073e9SAndroid Build Coastguard Worker 1585*da0073e9SAndroid Build Coastguard Worker # Input order partially named 1586*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "cannot contain a None name"): 1587*da0073e9SAndroid Build Coastguard Worker tensor.align_to('...', 'N', None) 1588*da0073e9SAndroid Build Coastguard Worker 1589*da0073e9SAndroid Build Coastguard Worker # Input order duplicate names 1590*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "duplicate names"): 1591*da0073e9SAndroid Build Coastguard Worker tensor.align_to('...', 'N', 'N') 1592*da0073e9SAndroid Build Coastguard Worker 1593*da0073e9SAndroid Build Coastguard Worker def test_align_as(self): 1594*da0073e9SAndroid Build Coastguard Worker # align_as calls align_to internally. align_to has pretty substantial tests, 1595*da0073e9SAndroid Build Coastguard Worker # so just test some basic things here. 1596*da0073e9SAndroid Build Coastguard Worker tensor = create('C:2,N:3,H:5') 1597*da0073e9SAndroid Build Coastguard Worker other = create('N:1,H:1,W:1,C:1') 1598*da0073e9SAndroid Build Coastguard Worker output = tensor.align_as(other) 1599*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, ['N', 'H', 'W', 'C']) 1600*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, [3, 5, 1, 2]) 1601*da0073e9SAndroid Build Coastguard Worker 1602*da0073e9SAndroid Build Coastguard Worker @unittest.skip("Not implemented yet") 1603*da0073e9SAndroid Build Coastguard Worker def test_align_tensors_two_inputs(self): 1604*da0073e9SAndroid Build Coastguard Worker def _test(tensor_namedshape, align_names, expected_sizes, expected_error): 1605*da0073e9SAndroid Build Coastguard Worker tensor_names, tensor_sizes = tensor_namedshape 1606*da0073e9SAndroid Build Coastguard Worker tensor = torch.empty(*tensor_sizes, names=tensor_names) 1607*da0073e9SAndroid Build Coastguard Worker other = torch.empty([1] * len(align_names), names=align_names) 1608*da0073e9SAndroid Build Coastguard Worker if expected_error is not None: 1609*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, expected_error): 1610*da0073e9SAndroid Build Coastguard Worker torch.align_tensors(tensor, other) 1611*da0073e9SAndroid Build Coastguard Worker return 1612*da0073e9SAndroid Build Coastguard Worker 1613*da0073e9SAndroid Build Coastguard Worker output, _ = torch.align_tensors(tensor, other) 1614*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.shape, expected_sizes) 1615*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output.names, align_names) 1616*da0073e9SAndroid Build Coastguard Worker 1617*da0073e9SAndroid Build Coastguard Worker Case = namedtuple('Case', [ 1618*da0073e9SAndroid Build Coastguard Worker 'tensor_namedshape', 1619*da0073e9SAndroid Build Coastguard Worker 'align_names', 1620*da0073e9SAndroid Build Coastguard Worker 'expected_sizes', 1621*da0073e9SAndroid Build Coastguard Worker 'expected_error', 1622*da0073e9SAndroid Build Coastguard Worker ]) 1623*da0073e9SAndroid Build Coastguard Worker 1624*da0073e9SAndroid Build Coastguard Worker tests = [ 1625*da0073e9SAndroid Build Coastguard Worker # basic tests 1626*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=(['C'], [2]), 1627*da0073e9SAndroid Build Coastguard Worker align_names=['C'], 1628*da0073e9SAndroid Build Coastguard Worker expected_sizes=[2], 1629*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1630*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=(['C'], [2]), 1631*da0073e9SAndroid Build Coastguard Worker align_names=['D'], 1632*da0073e9SAndroid Build Coastguard Worker expected_sizes=None, 1633*da0073e9SAndroid Build Coastguard Worker expected_error='not a subsequence'), 1634*da0073e9SAndroid Build Coastguard Worker 1635*da0073e9SAndroid Build Coastguard Worker # single-dim alignment test 1636*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=(['C'], [2]), 1637*da0073e9SAndroid Build Coastguard Worker align_names=['N', 'C'], 1638*da0073e9SAndroid Build Coastguard Worker expected_sizes=[1, 2], 1639*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1640*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[['N'], [2]], 1641*da0073e9SAndroid Build Coastguard Worker align_names=['N', 'C'], 1642*da0073e9SAndroid Build Coastguard Worker expected_sizes=[2, 1], 1643*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1644*da0073e9SAndroid Build Coastguard Worker 1645*da0073e9SAndroid Build Coastguard Worker # multiple dim alignment test 1646*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[['N', 'C'], [2, 3]], 1647*da0073e9SAndroid Build Coastguard Worker align_names=['N', 'H', 'C', 'W'], 1648*da0073e9SAndroid Build Coastguard Worker expected_sizes=[2, 1, 3, 1], 1649*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1650*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[['N', 'C'], [2, 3]], 1651*da0073e9SAndroid Build Coastguard Worker align_names=['C', 'H', 'N', 'W'], 1652*da0073e9SAndroid Build Coastguard Worker expected_sizes=None, 1653*da0073e9SAndroid Build Coastguard Worker expected_error='not a subsequence'), 1654*da0073e9SAndroid Build Coastguard Worker 1655*da0073e9SAndroid Build Coastguard Worker # scalar tensor tests 1656*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[None, [[]]], 1657*da0073e9SAndroid Build Coastguard Worker align_names=['N', 'C'], 1658*da0073e9SAndroid Build Coastguard Worker expected_sizes=[1, 1], 1659*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1660*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[[], [[]]], 1661*da0073e9SAndroid Build Coastguard Worker align_names=[None, None], 1662*da0073e9SAndroid Build Coastguard Worker expected_sizes=[1, 1], 1663*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1664*da0073e9SAndroid Build Coastguard Worker 1665*da0073e9SAndroid Build Coastguard Worker # unnamed tensor tests 1666*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[None, [2, 3]], 1667*da0073e9SAndroid Build Coastguard Worker align_names=[None, None], 1668*da0073e9SAndroid Build Coastguard Worker expected_sizes=[2, 3], 1669*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1670*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[None, [2, 3]], 1671*da0073e9SAndroid Build Coastguard Worker align_names=[None, None, None], 1672*da0073e9SAndroid Build Coastguard Worker expected_sizes=[1, 2, 3], 1673*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1674*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[None, [2]], 1675*da0073e9SAndroid Build Coastguard Worker align_names=['N'], 1676*da0073e9SAndroid Build Coastguard Worker expected_sizes=None, 1677*da0073e9SAndroid Build Coastguard Worker expected_error='not a subsequence'), 1678*da0073e9SAndroid Build Coastguard Worker 1679*da0073e9SAndroid Build Coastguard Worker # unnamed dim alignment tests 1680*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[[None], [2]], 1681*da0073e9SAndroid Build Coastguard Worker align_names=['N', None], 1682*da0073e9SAndroid Build Coastguard Worker expected_sizes=[1, 2], 1683*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1684*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[[None], [2]], 1685*da0073e9SAndroid Build Coastguard Worker align_names=['N', None, None, None], 1686*da0073e9SAndroid Build Coastguard Worker expected_sizes=[1, 1, 1, 2], 1687*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1688*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[['N'], [2]], 1689*da0073e9SAndroid Build Coastguard Worker align_names=['N', None, None, None], 1690*da0073e9SAndroid Build Coastguard Worker expected_sizes=[2, 1, 1, 1], 1691*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1692*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[[None, 'N', None], [2, 3, 5]], 1693*da0073e9SAndroid Build Coastguard Worker align_names=[None, None, 'N', None], 1694*da0073e9SAndroid Build Coastguard Worker expected_sizes=[1, 2, 3, 5], 1695*da0073e9SAndroid Build Coastguard Worker expected_error=None), 1696*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[[None], [2]], 1697*da0073e9SAndroid Build Coastguard Worker align_names=[None, 'N'], 1698*da0073e9SAndroid Build Coastguard Worker expected_sizes=None, 1699*da0073e9SAndroid Build Coastguard Worker expected_error='absolute position from the right'), 1700*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[None, [2]], 1701*da0073e9SAndroid Build Coastguard Worker align_names=[None, 'N'], 1702*da0073e9SAndroid Build Coastguard Worker expected_sizes=None, 1703*da0073e9SAndroid Build Coastguard Worker expected_error='absolute position from the right'), 1704*da0073e9SAndroid Build Coastguard Worker Case(tensor_namedshape=[[None, 'N'], [2, 3]], 1705*da0073e9SAndroid Build Coastguard Worker align_names=[None, 'C', 'N'], 1706*da0073e9SAndroid Build Coastguard Worker expected_sizes=None, 1707*da0073e9SAndroid Build Coastguard Worker expected_error='absolute position from the right'), 1708*da0073e9SAndroid Build Coastguard Worker ] 1709*da0073e9SAndroid Build Coastguard Worker 1710*da0073e9SAndroid Build Coastguard Worker for test in tests: 1711*da0073e9SAndroid Build Coastguard Worker _test(*test) 1712*da0073e9SAndroid Build Coastguard Worker 1713*da0073e9SAndroid Build Coastguard Worker @unittest.skip("Not implemented yet") 1714*da0073e9SAndroid Build Coastguard Worker def test_align_tensors(self): 1715*da0073e9SAndroid Build Coastguard Worker def reference_fn(*tensors): 1716*da0073e9SAndroid Build Coastguard Worker longest_names = tensors[0].names 1717*da0073e9SAndroid Build Coastguard Worker for tensor in tensors: 1718*da0073e9SAndroid Build Coastguard Worker if len(tensor.names) > len(longest_names): 1719*da0073e9SAndroid Build Coastguard Worker longest_names = tensor.names 1720*da0073e9SAndroid Build Coastguard Worker return [tensor.align_to(*longest_names) for tensor in tensors] 1721*da0073e9SAndroid Build Coastguard Worker 1722*da0073e9SAndroid Build Coastguard Worker x = torch.empty(1, 1, names=('N', 'H')) 1723*da0073e9SAndroid Build Coastguard Worker y = torch.empty(2, 3, 5, names=('N', 'C', 'H')) 1724*da0073e9SAndroid Build Coastguard Worker z = torch.empty(2, names=('N',)) 1725*da0073e9SAndroid Build Coastguard Worker output = torch.align_tensors(x, y, z) 1726*da0073e9SAndroid Build Coastguard Worker expected_tensors = reference_fn(x, y, z) 1727*da0073e9SAndroid Build Coastguard Worker for tensor, expected in zip(output, expected_tensors): 1728*da0073e9SAndroid Build Coastguard Worker self.assertTensorDataAndNamesEqual(tensor, expected) 1729*da0073e9SAndroid Build Coastguard Worker 1730*da0073e9SAndroid Build Coastguard Worker def test_mm(self): 1731*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 1732*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1733*da0073e9SAndroid Build Coastguard Worker torch.mm, device=device, 1734*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,C:2'), create('W:2,H:5')), 1735*da0073e9SAndroid Build Coastguard Worker expected_names=('N', 'H')) 1736*da0073e9SAndroid Build Coastguard Worker 1737*da0073e9SAndroid Build Coastguard Worker # left arg is unnamed 1738*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1739*da0073e9SAndroid Build Coastguard Worker torch.mm, device=device, 1740*da0073e9SAndroid Build Coastguard Worker args=(create('3,2'), create('W:2,H:5')), 1741*da0073e9SAndroid Build Coastguard Worker expected_names=(None, 'H')) 1742*da0073e9SAndroid Build Coastguard Worker 1743*da0073e9SAndroid Build Coastguard Worker # right arg is unnamed 1744*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1745*da0073e9SAndroid Build Coastguard Worker torch.mm, device=device, 1746*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,C:2'), create('2,5')), 1747*da0073e9SAndroid Build Coastguard Worker expected_names=('N', None)) 1748*da0073e9SAndroid Build Coastguard Worker 1749*da0073e9SAndroid Build Coastguard Worker # out= 1750*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1751*da0073e9SAndroid Build Coastguard Worker out_fn(torch.mm), device=device, 1752*da0073e9SAndroid Build Coastguard Worker args=(create('0'), create('N:3,C:2'), create('W:2,H:5')), 1753*da0073e9SAndroid Build Coastguard Worker expected_names=('N', 'H')) 1754*da0073e9SAndroid Build Coastguard Worker 1755*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1756*da0073e9SAndroid Build Coastguard Worker torch.mm, device=device, 1757*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,C:2'), create('W:2,N:5')), 1758*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='with duplicate names') 1759*da0073e9SAndroid Build Coastguard Worker 1760*da0073e9SAndroid Build Coastguard Worker def test_expand(self): 1761*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 1762*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1763*da0073e9SAndroid Build Coastguard Worker Tensor.expand, device=device, 1764*da0073e9SAndroid Build Coastguard Worker args=(create('D:1'), [3]), expected_names=('D',)) 1765*da0073e9SAndroid Build Coastguard Worker 1766*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1767*da0073e9SAndroid Build Coastguard Worker Tensor.expand, device=device, 1768*da0073e9SAndroid Build Coastguard Worker args=(create('H:3,W:2'), [10, 3, 3, 2]), 1769*da0073e9SAndroid Build Coastguard Worker expected_names=(None, None, 'H', 'W')) 1770*da0073e9SAndroid Build Coastguard Worker 1771*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1772*da0073e9SAndroid Build Coastguard Worker Tensor.expand, device=device, 1773*da0073e9SAndroid Build Coastguard Worker args=(create('3, 2'), [10, 3, 3, 2]), 1774*da0073e9SAndroid Build Coastguard Worker expected_names=(None, None, None, None)) 1775*da0073e9SAndroid Build Coastguard Worker 1776*da0073e9SAndroid Build Coastguard Worker def test_addmm(self): 1777*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 1778*da0073e9SAndroid Build Coastguard Worker # full names 1779*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1780*da0073e9SAndroid Build Coastguard Worker torch.addmm, device=device, 1781*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')), 1782*da0073e9SAndroid Build Coastguard Worker expected_names=('N', 'H')) 1783*da0073e9SAndroid Build Coastguard Worker 1784*da0073e9SAndroid Build Coastguard Worker # no name on bias 1785*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1786*da0073e9SAndroid Build Coastguard Worker torch.addmm, device=device, 1787*da0073e9SAndroid Build Coastguard Worker args=(create('3,5'), create('N:3,C:2'), create('W:2,H:5')), 1788*da0073e9SAndroid Build Coastguard Worker expected_names=('N', 'H')) 1789*da0073e9SAndroid Build Coastguard Worker 1790*da0073e9SAndroid Build Coastguard Worker # partially named bias 1791*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1792*da0073e9SAndroid Build Coastguard Worker torch.addmm, device=device, 1793*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')), 1794*da0073e9SAndroid Build Coastguard Worker expected_names=('N', 'H')) 1795*da0073e9SAndroid Build Coastguard Worker 1796*da0073e9SAndroid Build Coastguard Worker # out= 1797*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1798*da0073e9SAndroid Build Coastguard Worker out_fn(torch.addmm), device=device, 1799*da0073e9SAndroid Build Coastguard Worker args=(create('0'), create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')), 1800*da0073e9SAndroid Build Coastguard Worker expected_names=('N', 'H')) 1801*da0073e9SAndroid Build Coastguard Worker 1802*da0073e9SAndroid Build Coastguard Worker # inplace 1803*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1804*da0073e9SAndroid Build Coastguard Worker torch.Tensor.addmm_, device=device, 1805*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')), 1806*da0073e9SAndroid Build Coastguard Worker expected_names=('N', 'H')) 1807*da0073e9SAndroid Build Coastguard Worker 1808*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1809*da0073e9SAndroid Build Coastguard Worker torch.addmm, device=device, 1810*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,N:5')), 1811*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='with duplicate names') 1812*da0073e9SAndroid Build Coastguard Worker 1813*da0073e9SAndroid Build Coastguard Worker def test_bmm(self): 1814*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 1815*da0073e9SAndroid Build Coastguard Worker # full names 1816*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1817*da0073e9SAndroid Build Coastguard Worker torch.bmm, device=device, 1818*da0073e9SAndroid Build Coastguard Worker args=(create('N:7,A:3,B:2'), create('N:7,A:2,B:5')), 1819*da0073e9SAndroid Build Coastguard Worker expected_names=('N', 'A', 'B')) 1820*da0073e9SAndroid Build Coastguard Worker 1821*da0073e9SAndroid Build Coastguard Worker # no name on left tensor 1822*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1823*da0073e9SAndroid Build Coastguard Worker torch.bmm, device=device, 1824*da0073e9SAndroid Build Coastguard Worker args=(create('7,3,2'), create('N:7,A:2,B:5')), 1825*da0073e9SAndroid Build Coastguard Worker expected_names=('N', None, 'B')) 1826*da0073e9SAndroid Build Coastguard Worker 1827*da0073e9SAndroid Build Coastguard Worker # no name on right tensor 1828*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1829*da0073e9SAndroid Build Coastguard Worker torch.bmm, device=device, 1830*da0073e9SAndroid Build Coastguard Worker args=(create('N:7,A:3,B:2'), create('7,2,5')), 1831*da0073e9SAndroid Build Coastguard Worker expected_names=('N', 'A', None)) 1832*da0073e9SAndroid Build Coastguard Worker 1833*da0073e9SAndroid Build Coastguard Worker # out= 1834*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1835*da0073e9SAndroid Build Coastguard Worker out_fn(torch.bmm), device=device, 1836*da0073e9SAndroid Build Coastguard Worker args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')), 1837*da0073e9SAndroid Build Coastguard Worker expected_names=('N', 'A', 'B')) 1838*da0073e9SAndroid Build Coastguard Worker 1839*da0073e9SAndroid Build Coastguard Worker # duplicate names after mm 1840*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1841*da0073e9SAndroid Build Coastguard Worker torch.bmm, device=device, 1842*da0073e9SAndroid Build Coastguard Worker args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')), 1843*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='with duplicate names') 1844*da0073e9SAndroid Build Coastguard Worker 1845*da0073e9SAndroid Build Coastguard Worker # matching error (batch dimensions must be alignable) 1846*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1847*da0073e9SAndroid Build Coastguard Worker torch.bmm, device=device, 1848*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,A:3,B:3'), create('M:3,A:3,B:3')), 1849*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='do not match') 1850*da0073e9SAndroid Build Coastguard Worker 1851*da0073e9SAndroid Build Coastguard Worker # misalignment (batch dimension is getting contracted) 1852*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1853*da0073e9SAndroid Build Coastguard Worker torch.bmm, device=device, 1854*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,A:3,B:3'), create('None:3,N:3,B:3')), 1855*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='misaligned') 1856*da0073e9SAndroid Build Coastguard Worker 1857*da0073e9SAndroid Build Coastguard Worker def test_matmul(self): 1858*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 1859*da0073e9SAndroid Build Coastguard Worker # input tensors are less than 1D 1860*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1861*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1862*da0073e9SAndroid Build Coastguard Worker args=(create(''), create('A:2')), 1863*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='at least 1D') 1864*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1865*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1866*da0073e9SAndroid Build Coastguard Worker args=(create('A:2'), create('')), 1867*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='at least 1D') 1868*da0073e9SAndroid Build Coastguard Worker 1869*da0073e9SAndroid Build Coastguard Worker # 1D @ 1D 1870*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1871*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1872*da0073e9SAndroid Build Coastguard Worker args=(create('A:2'), create('B:2')), 1873*da0073e9SAndroid Build Coastguard Worker expected_names=[]) 1874*da0073e9SAndroid Build Coastguard Worker 1875*da0073e9SAndroid Build Coastguard Worker # ND @ 1D 1876*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1877*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1878*da0073e9SAndroid Build Coastguard Worker args=(create('A:3,C:2'), create('B:2')), 1879*da0073e9SAndroid Build Coastguard Worker expected_names=['A']) 1880*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1881*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1882*da0073e9SAndroid Build Coastguard Worker args=(create('A:5,C:3,D:2'), create('B:2')), 1883*da0073e9SAndroid Build Coastguard Worker expected_names=['A', 'C']) 1884*da0073e9SAndroid Build Coastguard Worker 1885*da0073e9SAndroid Build Coastguard Worker # 1D @ ND 1886*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1887*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1888*da0073e9SAndroid Build Coastguard Worker args=(create('C:2'), create('A:2,B:3')), 1889*da0073e9SAndroid Build Coastguard Worker expected_names=['B']) 1890*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1891*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1892*da0073e9SAndroid Build Coastguard Worker args=(create('C:2'), create('A:3,B:2,D:5')), 1893*da0073e9SAndroid Build Coastguard Worker expected_names=['A', 'D']) 1894*da0073e9SAndroid Build Coastguard Worker 1895*da0073e9SAndroid Build Coastguard Worker # 2D @ 2D 1896*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1897*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1898*da0073e9SAndroid Build Coastguard Worker args=(create('A:3,B:2'), create('A:2,B:3')), 1899*da0073e9SAndroid Build Coastguard Worker expected_names=['A', 'B']) 1900*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1901*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1902*da0073e9SAndroid Build Coastguard Worker args=(create('A:3,B:2'), create('B:2,A:5')), 1903*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='with duplicate names') 1904*da0073e9SAndroid Build Coastguard Worker 1905*da0073e9SAndroid Build Coastguard Worker # ND @ ND where N >= 2 1906*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1907*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1908*da0073e9SAndroid Build Coastguard Worker args=(create('C:5,A:3,B:2'), create('A:2,B:3')), 1909*da0073e9SAndroid Build Coastguard Worker expected_names=['C', 'A', 'B']) 1910*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1911*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1912*da0073e9SAndroid Build Coastguard Worker args=(create('C:5,A:3,B:2'), create('None:1,A:2,B:3')), 1913*da0073e9SAndroid Build Coastguard Worker expected_names=['C', 'A', 'B']) 1914*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1915*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1916*da0073e9SAndroid Build Coastguard Worker args=(create('C:5,A:3,B:2'), create('None:2,None:1,A:2,B:3')), 1917*da0073e9SAndroid Build Coastguard Worker expected_names=[None, 'C', 'A', 'B']) 1918*da0073e9SAndroid Build Coastguard Worker 1919*da0073e9SAndroid Build Coastguard Worker # out= 1920*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1921*da0073e9SAndroid Build Coastguard Worker out_fn(torch.matmul), device=device, 1922*da0073e9SAndroid Build Coastguard Worker args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')), 1923*da0073e9SAndroid Build Coastguard Worker expected_names=('N', 'A', 'B')) 1924*da0073e9SAndroid Build Coastguard Worker 1925*da0073e9SAndroid Build Coastguard Worker # duplicate names after mm 1926*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1927*da0073e9SAndroid Build Coastguard Worker torch.bmm, device=device, 1928*da0073e9SAndroid Build Coastguard Worker args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')), 1929*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='with duplicate names') 1930*da0073e9SAndroid Build Coastguard Worker 1931*da0073e9SAndroid Build Coastguard Worker # misalignment (batch dimension is getting contracted) 1932*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1933*da0073e9SAndroid Build Coastguard Worker torch.matmul, device=device, 1934*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,A:3,B:3'), create('A:3,N:3,B:3')), 1935*da0073e9SAndroid Build Coastguard Worker maybe_raises_regex='do not match') 1936*da0073e9SAndroid Build Coastguard Worker 1937*da0073e9SAndroid Build Coastguard Worker def test_mv(self): 1938*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 1939*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1940*da0073e9SAndroid Build Coastguard Worker torch.mv, device=device, 1941*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,C:2'), create('W:2')), 1942*da0073e9SAndroid Build Coastguard Worker expected_names=('N',)) 1943*da0073e9SAndroid Build Coastguard Worker 1944*da0073e9SAndroid Build Coastguard Worker # left arg is unnamed 1945*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1946*da0073e9SAndroid Build Coastguard Worker torch.mv, device=device, 1947*da0073e9SAndroid Build Coastguard Worker args=(create('3,2'), create('W:2')), 1948*da0073e9SAndroid Build Coastguard Worker expected_names=(None,)) 1949*da0073e9SAndroid Build Coastguard Worker 1950*da0073e9SAndroid Build Coastguard Worker # right arg is unnamed 1951*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1952*da0073e9SAndroid Build Coastguard Worker torch.mv, device=device, 1953*da0073e9SAndroid Build Coastguard Worker args=(create('N:3,C:2'), create('2')), 1954*da0073e9SAndroid Build Coastguard Worker expected_names=('N',)) 1955*da0073e9SAndroid Build Coastguard Worker 1956*da0073e9SAndroid Build Coastguard Worker # out= 1957*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1958*da0073e9SAndroid Build Coastguard Worker out_fn(torch.mv), device=device, 1959*da0073e9SAndroid Build Coastguard Worker args=(create('0'), create('N:3,C:2'), create('W:2')), 1960*da0073e9SAndroid Build Coastguard Worker expected_names=('N',)) 1961*da0073e9SAndroid Build Coastguard Worker 1962*da0073e9SAndroid Build Coastguard Worker def test_addmv(self): 1963*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 1964*da0073e9SAndroid Build Coastguard Worker # full names 1965*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1966*da0073e9SAndroid Build Coastguard Worker torch.addmv, device=device, 1967*da0073e9SAndroid Build Coastguard Worker args=(create('N:3'), create('N:3,C:2'), create('H:2')), 1968*da0073e9SAndroid Build Coastguard Worker expected_names=['N']) 1969*da0073e9SAndroid Build Coastguard Worker 1970*da0073e9SAndroid Build Coastguard Worker # no name on bias 1971*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1972*da0073e9SAndroid Build Coastguard Worker torch.addmv, device=device, 1973*da0073e9SAndroid Build Coastguard Worker args=(create('3'), create('N:3,C:2'), create('H:2')), 1974*da0073e9SAndroid Build Coastguard Worker expected_names=('N',)) 1975*da0073e9SAndroid Build Coastguard Worker 1976*da0073e9SAndroid Build Coastguard Worker # out= 1977*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1978*da0073e9SAndroid Build Coastguard Worker out_fn(torch.addmv), device=device, 1979*da0073e9SAndroid Build Coastguard Worker args=(create('0'), create('N:3'), create('N:3,C:2'), create('H:2')), 1980*da0073e9SAndroid Build Coastguard Worker expected_names=('N',)) 1981*da0073e9SAndroid Build Coastguard Worker 1982*da0073e9SAndroid Build Coastguard Worker # inplace 1983*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 1984*da0073e9SAndroid Build Coastguard Worker torch.Tensor.addmv_, device=device, 1985*da0073e9SAndroid Build Coastguard Worker args=(create('N:3'), create('N:3,C:2'), create('H:2')), 1986*da0073e9SAndroid Build Coastguard Worker expected_names=('N',)) 1987*da0073e9SAndroid Build Coastguard Worker 1988*da0073e9SAndroid Build Coastguard Worker def test_autograd_ignores_names(self): 1989*da0073e9SAndroid Build Coastguard Worker # sigmoid forward is supported by named tensors, but sigmoid_backward 1990*da0073e9SAndroid Build Coastguard Worker # is not (see native_functions.yaml). Test that autograd ignores names 1991*da0073e9SAndroid Build Coastguard Worker # and that the sigmoid_backward succeeds. 1992*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3, names=('N', 'C'), requires_grad=True) 1993*da0073e9SAndroid Build Coastguard Worker x.sigmoid().sum().backward() 1994*da0073e9SAndroid Build Coastguard Worker 1995*da0073e9SAndroid Build Coastguard Worker def test_tensor_grad_is_unnamed(self): 1996*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3, names=(None, None), requires_grad=True) 1997*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, 3, names=('N', 'C'), requires_grad=True) 1998*da0073e9SAndroid Build Coastguard Worker (x * y).sum().backward() 1999*da0073e9SAndroid Build Coastguard Worker 2000*da0073e9SAndroid Build Coastguard Worker # Check that names weren't propagated 2001*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.grad.names, [None, None]) 2002*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad.names, [None, None]) 2003*da0073e9SAndroid Build Coastguard Worker 2004*da0073e9SAndroid Build Coastguard Worker def test_autograd_warns_named_grad(self): 2005*da0073e9SAndroid Build Coastguard Worker base = torch.randn(3, 3, names=('N', 'C')) 2006*da0073e9SAndroid Build Coastguard Worker named_grad = base.clone() 2007*da0073e9SAndroid Build Coastguard Worker base.requires_grad_() 2008*da0073e9SAndroid Build Coastguard Worker 2009*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as warns: 2010*da0073e9SAndroid Build Coastguard Worker # Cause all warnings to always be triggered. 2011*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") 2012*da0073e9SAndroid Build Coastguard Worker base.clone().backward(named_grad) 2013*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(warns), 1) 2014*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2015*da0073e9SAndroid Build Coastguard Worker str(warns[0].message).startswith('Autograd was passed a named grad tensor')) 2016*da0073e9SAndroid Build Coastguard Worker 2017*da0073e9SAndroid Build Coastguard Worker def test_nyi_dimname_overload_msg(self): 2018*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3) 2019*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "squeeze: You passed a dimname"): 2020*da0073e9SAndroid Build Coastguard Worker x.squeeze_("N") 2021*da0073e9SAndroid Build Coastguard Worker 2022*da0073e9SAndroid Build Coastguard Worker def test_dot(self): 2023*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 2024*da0073e9SAndroid Build Coastguard Worker # torch.dot ignores the names of both tensors 2025*da0073e9SAndroid Build Coastguard Worker self._test_name_inference( 2026*da0073e9SAndroid Build Coastguard Worker torch.dot, device=device, 2027*da0073e9SAndroid Build Coastguard Worker args=(create('C:2'), create('W:2')), 2028*da0073e9SAndroid Build Coastguard Worker expected_names=[]) 2029*da0073e9SAndroid Build Coastguard Worker 2030*da0073e9SAndroid Build Coastguard Worker def test_comparison_ops(self): 2031*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 2032*da0073e9SAndroid Build Coastguard Worker a = torch.randn(3, 3, names=('N', 'C'), device=device) 2033*da0073e9SAndroid Build Coastguard Worker b = torch.randn(3, 3, names=('N', 'C'), device=device) 2034*da0073e9SAndroid Build Coastguard Worker scalar = torch.randn([], device=device) 2035*da0073e9SAndroid Build Coastguard Worker 2036*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a == b).names, ['N', 'C']) 2037*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a != b).names, ['N', 'C']) 2038*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a > b).names, ['N', 'C']) 2039*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a < b).names, ['N', 'C']) 2040*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a >= b).names, ['N', 'C']) 2041*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a <= b).names, ['N', 'C']) 2042*da0073e9SAndroid Build Coastguard Worker 2043*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a == 1).names, ['N', 'C']) 2044*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a != 1).names, ['N', 'C']) 2045*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a > 1).names, ['N', 'C']) 2046*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a < 1).names, ['N', 'C']) 2047*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a >= 1).names, ['N', 'C']) 2048*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a <= 1).names, ['N', 'C']) 2049*da0073e9SAndroid Build Coastguard Worker 2050*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a == scalar).names, ['N', 'C']) 2051*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a != scalar).names, ['N', 'C']) 2052*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a > scalar).names, ['N', 'C']) 2053*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a < scalar).names, ['N', 'C']) 2054*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a >= scalar).names, ['N', 'C']) 2055*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a <= scalar).names, ['N', 'C']) 2056*da0073e9SAndroid Build Coastguard Worker 2057*da0073e9SAndroid Build Coastguard Worker res = torch.empty(3, 3, dtype=torch.bool, device=device) 2058*da0073e9SAndroid Build Coastguard Worker torch.eq(a, b, out=res) 2059*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.names, ['N', 'C']) 2060*da0073e9SAndroid Build Coastguard Worker torch.ne(a, b, out=res) 2061*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.names, ['N', 'C']) 2062*da0073e9SAndroid Build Coastguard Worker torch.lt(a, b, out=res) 2063*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.names, ['N', 'C']) 2064*da0073e9SAndroid Build Coastguard Worker torch.gt(a, b, out=res) 2065*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.names, ['N', 'C']) 2066*da0073e9SAndroid Build Coastguard Worker torch.le(a, b, out=res) 2067*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.names, ['N', 'C']) 2068*da0073e9SAndroid Build Coastguard Worker torch.ge(a, b, out=res) 2069*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.names, ['N', 'C']) 2070*da0073e9SAndroid Build Coastguard Worker 2071*da0073e9SAndroid Build Coastguard Worker res = torch.isnan(a) 2072*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.names, ['N', 'C']) 2073*da0073e9SAndroid Build Coastguard Worker 2074*da0073e9SAndroid Build Coastguard Worker res = torch.isinf(a) 2075*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.names, ['N', 'C']) 2076*da0073e9SAndroid Build Coastguard Worker 2077*da0073e9SAndroid Build Coastguard Worker def test_support_device_named_grad(self): 2078*da0073e9SAndroid Build Coastguard Worker named_tensor = torch.randn(3, 3, device='meta') 2079*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'NYI: named tensors only support CPU, CUDA'): 2080*da0073e9SAndroid Build Coastguard Worker named_tensor.rename_('N', 'C') 2081*da0073e9SAndroid Build Coastguard Worker named_tensor.names = ['N', 'C'] 2082*da0073e9SAndroid Build Coastguard Worker named_tensor = torch.randn(3, 3, device='meta', names=['N', 'C']) 2083*da0073e9SAndroid Build Coastguard Worker 2084*da0073e9SAndroid Build Coastguard Worker 2085*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 2086*da0073e9SAndroid Build Coastguard Worker run_tests() 2087