Lines Matching full:device
35 def test_index(self, device): argument
41 reference = consec((3, 3, 3)).to(device)
45 reference[torch.LongTensor().to(device)], reference.new(0, 3, 3)
75 reference_5d = consec((3, 3, 3, 3, 3)).to(device)
88 reference = consec((5, 5, 5)).to(device)
89 idx = torch.LongTensor([2, 4]).to(device)
117 reference = consec((10, 10, 10)).to(device)
143 tensor = torch.DoubleTensor(lst).to(device)
182 def test_advancedindex(self, device, dtype): argument
190 sequence = torch.ones(numel, dtype=torch.float, device=device).cumsum(0)
198 return torch.LongTensor(indices).to(device)
211 x[ri([0, 2, 4]),], torch.tensor([1, 3, 5], dtype=dtype, device=device)
216 self.assertEqual(x[[0]], torch.tensor([-2], dtype=dtype, device=device))
219 x[ri([0]),], torch.tensor([-1], dtype=dtype, device=device)
223 x[[2, 3, 4]], torch.tensor([4, 4, 4], dtype=dtype, device=device)
227 x[ri([2, 3, 4]),], torch.tensor([3, 3, 3], dtype=dtype, device=device)
229 x[ri([0, 2, 4]),] = torch.tensor([5, 4, 3], dtype=dtype, device=device)
231 x[ri([0, 2, 4]),], torch.tensor([5, 4, 3], dtype=dtype, device=device)
251 strided = torch.tensor((), dtype=dtype, device=device)
256 self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device))
258 strided[ri([0]),], torch.tensor([1], dtype=dtype, device=device)
261 strided[ri([3]),], torch.tensor([7], dtype=dtype, device=device)
264 strided[[1, 2]], torch.tensor([3, 5], dtype=dtype, device=device)
267 strided[ri([1, 2]),], torch.tensor([3, 5], dtype=dtype, device=device)
271 torch.tensor([[5, 3], [1, 7]], dtype=dtype, device=device),
275 strided = torch.tensor((), dtype=dtype, device=device)
279 self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device))
281 strided[ri([0]),], torch.tensor([5], dtype=dtype, device=device)
284 strided[ri([1]),], torch.tensor([9], dtype=dtype, device=device)
287 strided[[0, 1]], torch.tensor([5, 9], dtype=dtype, device=device)
290 strided[ri([0, 1]),], torch.tensor([5, 9], dtype=dtype, device=device)
294 torch.tensor([[5, 9], [9, 5]], dtype=dtype, device=device),
303 torch.tensor([1, 3, 5], dtype=dtype, device=device),
307 torch.tensor([2, 4, 6], dtype=dtype, device=device),
313 torch.tensor([1, 2], dtype=dtype, device=device),
317 torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device),
321 torch.tensor([1, 2, 3, 3], dtype=dtype, device=device),
328 torch.tensor([[1, 1], [3, 5]], dtype=dtype, device=device),
335 torch.tensor([[2, 1], [4, 5]], dtype=dtype, device=device),
341 torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device),
347 reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
350 [-1, 2, -4], dtype=dtype, device=device
354 torch.tensor([-1, 2, -4], dtype=dtype, device=device),
357 [[4, 6], [2, 3]], dtype=dtype, device=device
361 torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
367 [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype, device=device
377 torch.tensor([0, 1, 2], dtype=dtype, device=device),
381 torch.tensor([4, 5, 6], dtype=dtype, device=device),
384 reference[ri([0]), ri([0])], torch.tensor([0], dtype=dtype, device=device)
387 reference[ri([2]), ri([1])], torch.tensor([6], dtype=dtype, device=device)
391 torch.tensor([0, 4], dtype=dtype, device=device),
395 torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device),
399 torch.tensor([0, 4, 1, 1], dtype=dtype, device=device),
406 torch.tensor([[0, 0], [1, 2]], dtype=dtype, device=device),
413 torch.tensor([[4, 0], [5, 2]], dtype=dtype, device=device),
419 torch.tensor([[0, 4], [5, 11]], dtype=dtype, device=device),
425 reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
428 [-1, 2, -4], dtype=dtype, device=device
432 torch.tensor([-1, 2, -4], dtype=dtype, device=device),
435 [[4, 6], [2, 3]], dtype=dtype, device=device
439 torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
447 reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
448 strided = torch.tensor((), dtype=dtype, device=device)
453 torch.tensor([1, 9], dtype=dtype, device=device),
457 torch.tensor([3, 11], dtype=dtype, device=device),
460 strided[ri([0]), ri([0])], torch.tensor([1], dtype=dtype, device=device)
463 strided[ri([1]), ri([3])], torch.tensor([15], dtype=dtype, device=device)
467 torch.tensor([1, 7], dtype=dtype, device=device),
471 torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device),
475 torch.tensor([1, 3, 9, 9], dtype=dtype, device=device),
482 torch.tensor([[1, 1], [9, 9]], dtype=dtype, device=device),
489 torch.tensor([[3, 13], [11, 5]], dtype=dtype, device=device),
495 torch.tensor([[1, 3], [11, 13]], dtype=dtype, device=device),
503 reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
504 strided = torch.tensor((), dtype=dtype, device=device)
507 strided[ri([0]), ri([1])], torch.tensor([11], dtype=dtype, device=device)
511 strided[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device)
514 reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
515 strided = torch.tensor((), dtype=dtype, device=device)
519 torch.tensor([11, 17], dtype=dtype, device=device),
522 [-1, 2], dtype=dtype, device=device
526 torch.tensor([-1, 2], dtype=dtype, device=device),
529 reference = torch.arange(0.0, 24, dtype=dtype, device=device).view(3, 8)
530 strided = torch.tensor((), dtype=dtype, device=device)
537 torch.tensor([[10, 11], [17, 18]], dtype=dtype, device=device),
540 [[4, 6], [2, 3]], dtype=dtype, device=device
544 torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device),
555 torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device),
558 reference[ri([1]), ...], torch.tensor([[3, 4]], dtype=dtype, device=device)
562 torch.tensor([[2], [4], [6]], dtype=dtype, device=device),
570 reference = torch.empty(10, dtype=dtype, device=device)
571 # can't test cuda because it is a device assert
577 reference[torch.LongTensor([err_idx]).to(device)]
583 tensor = tensor.to(device="cpu")
597 return torch.tensor(npt[idxs], dtype=dtype, device=device)
617 set_numpy(numt, indexer, val), dtype=dtype, device=device
626 dev = cpu.to(device).detach().requires_grad_(True)
628 outdev.backward(gOcpu.to(device))
634 set_tensor = torch.randperm(set_count).view(set_size).double().to(device)
641 reference = torch.arange(0.0, 20, dtype=dtype, device=device).view(4, 5)
668 reference = torch.arange(0.0, 160, dtype=dtype, device=device).view(4, 8, 5)
720 reference = torch.arange(0.0, 1296, dtype=dtype, device=device).view(3, 9, 8, 6)
798 def test_advancedindex_big(self, device): argument
799 reference = torch.arange(0, 123344, dtype=torch.int, device=device)
806 def test_set_item_to_scalar_tensor(self, device): argument
809 z = torch.randn([m, n], device=device)
811 w = torch.tensor(a, requires_grad=True, device=device)
816 def test_single_int(self, device): argument
817 v = torch.randn(5, 7, 3, device=device)
820 def test_multiple_int(self, device): argument
821 v = torch.randn(5, 7, 3, device=device)
825 def test_none(self, device): argument
826 v = torch.randn(5, 7, 3, device=device)
832 def test_step(self, device): argument
833 v = torch.arange(10, device=device)
840 def test_step_assignment(self, device): argument
841 v = torch.zeros(4, 4, device=device)
842 v[0, 1::2] = torch.tensor([3.0, 4.0], device=device)
846 def test_bool_indices(self, device): argument
847 v = torch.randn(5, 7, 3, device=device)
849 [True, False, True, True, False], dtype=torch.bool, device=device
854 v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
856 [True, False, False], dtype=torch.bool, device=device
858 uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
865 v[boolIndices], tensor([True], dtype=torch.bool, device=device)
869 def test_bool_indices_accumulate(self, device): argument
870 mask = torch.zeros(size=(10,), dtype=torch.bool, device=device)
871 y = torch.ones(size=(10, 10), device=device)
873 self.assertEqual(y, torch.ones(size=(10, 10), device=device))
875 def test_multiple_bool_indices(self, device): argument
876 v = torch.randn(5, 7, 3, device=device)
878 mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
879 mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
882 def test_byte_mask(self, device): argument
883 v = torch.randn(5, 7, 3, device=device)
884 mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
891 v = torch.tensor([1.0], device=device)
892 self.assertEqual(v[v == 0], torch.tensor([], device=device))
894 def test_byte_mask_accumulate(self, device): argument
895 mask = torch.zeros(size=(10,), dtype=torch.uint8, device=device)
896 y = torch.ones(size=(10, 10), device=device)
900 self.assertEqual(y, torch.ones(size=(10, 10), device=device))
907 def test_index_put_accumulate_large_tensor(self, device): argument
911 a = torch.ones(N, dtype=dt, device=device)
913 [-2, 0, -2, -1, 0, -1, 1], device=device, dtype=torch.long
915 values = torch.tensor([6, 5, 6, 6, 5, 7, 11], dtype=dt, device=device)
926 a = torch.ones((2, N), dtype=dt, device=device)
927 indices0 = torch.tensor([0, -1, 0, 1], device=device, dtype=torch.long)
928 indices1 = torch.tensor([-2, -1, 0, 1], device=device, dtype=torch.long)
929 values = torch.tensor([12, 13, 10, 11], dtype=dt, device=device)
945 def test_index_put_accumulate_expanded_values(self, device): argument
949 t_dev = t.to(device)
951 indices_dev = [i.to(device) for i in indices]
955 out_cuda = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
959 out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
964 t_dev = t.to(device)
971 indices_dev = [i.to(device) for i in indices]
975 out_cuda = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
979 out_cuda = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
984 def test_index_put_accumulate_non_contiguous(self, device): argument
986 t_dev = t.to(device)
993 indices_dev = [i.to(device) for i in indices]
995 out_cuda = t1.index_put_(indices_dev, value.to(device), accumulate=True)
1004 def test_index_put_accumulate_with_optional_tensors(self, device): argument
1017 t_dev = t.to(device)
1019 indices_dev = indices.to(device)
1032 def test_index_put_accumulate_duplicate_indices(self, device): argument
1036 delta = torch.empty(i, dtype=torch.double, device=device).uniform_(-1, 1)
1039 input = torch.randn(indices.abs().max() + 1, device=device)
1040 values = torch.randn(indices.size(0), device=device)
1052 def test_index_ind_dtype(self, device): argument
1053 x = torch.randn(4, 4, device=device)
1054 ind_long = torch.randint(4, (4,), dtype=torch.long, device=device)
1056 src = torch.randn(4, device=device)
1067 ind_long = torch.arange(4, dtype=torch.long, device=device)
1077 def test_index_put_accumulate_empty(self, device): argument
1079 input = torch.rand([], dtype=torch.float32, device=device)
1081 input.index_put([], torch.tensor([1.0], device=device), True)
1083 def test_multiple_byte_mask(self, device): argument
1084 v = torch.randn(5, 7, 3, device=device)
1086 mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
1087 mask2 = torch.ByteTensor([1, 1, 1]).to(device)
1093 def test_byte_mask2d(self, device): argument
1094 v = torch.randn(5, 7, 3, device=device)
1095 c = torch.randn(5, 7, device=device)
1101 def test_jit_indexing(self, device): argument
1112 data = torch.arange(100, device=device, dtype=torch.float)
1116 device=device,
1123 def test_int_indices(self, device): argument
1124 v = torch.randn(5, 7, 3, device=device)
1145 def test_index_put_src_datatype(self, device, dtype): argument
1146 src = torch.ones(3, 2, 4, device=device, dtype=dtype)
1147 vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
1155 def test_index_src_datatype(self, device, dtype): argument
1156 src = torch.ones(3, 2, 4, device=device, dtype=dtype)
1164 def test_int_indices2d(self, device): argument
1166 x = torch.arange(0, 12, device=device).view(4, 3)
1167 rows = torch.tensor([[0, 0], [3, 3]], device=device)
1168 columns = torch.tensor([[0, 2], [0, 2]], device=device)
1171 def test_int_indices_broadcast(self, device): argument
1173 x = torch.arange(0, 12, device=device).view(4, 3)
1174 rows = torch.tensor([0, 3], device=device)
1175 columns = torch.tensor([0, 2], device=device)
1179 def test_empty_index(self, device): argument
1180 x = torch.arange(0, 12, device=device).view(4, 3)
1181 idx = torch.tensor([], dtype=torch.long, device=device)
1189 mask = torch.zeros(4, 3, device=device).bool()
1193 def test_empty_ndim_index(self, device): argument
1194 x = torch.randn(5, device=device)
1196 torch.empty(0, 2, device=device),
1197 x[torch.empty(0, 2, dtype=torch.int64, device=device)],
1200 x = torch.randn(2, 3, 4, 5, device=device)
1202 torch.empty(2, 0, 6, 4, 5, device=device),
1203 x[:, torch.empty(0, 6, dtype=torch.int64, device=device)],
1206 x = torch.empty(10, 0, device=device)
1212 def test_empty_ndim_index_bool(self, device): argument
1213 x = torch.randn(5, device=device)
1215 IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)]
1218 def test_empty_slice(self, device): argument
1219 x = torch.randn(2, 3, 4, 5, device=device)
1227 def test_index_getitem_copy_bools_slices(self, device): argument
1228 true = torch.tensor(1, dtype=torch.uint8, device=device)
1229 false = torch.tensor(0, dtype=torch.uint8, device=device)
1231 tensors = [torch.randn(2, 3, device=device), torch.tensor(3.0, device=device)]
1241 def test_index_setitem_bools_slices(self, device): argument
1242 true = torch.tensor(1, dtype=torch.uint8, device=device)
1243 false = torch.tensor(0, dtype=torch.uint8, device=device)
1245 tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]
1268 def test_index_scalar_with_bool_mask(self, device): argument
1269 a = torch.tensor(1, device=device)
1270 uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
1271 boolMask = torch.tensor(True, dtype=torch.bool, device=device)
1275 a = torch.tensor(True, dtype=torch.bool, device=device)
1279 def test_setitem_expansion_error(self, device): argument
1280 true = torch.tensor(True, device=device)
1281 a = torch.randn(2, 3, device=device)
1290 def test_getitem_scalars(self, device): argument
1291 zero = torch.tensor(0, dtype=torch.int64, device=device)
1292 one = torch.tensor(1, dtype=torch.int64, device=device)
1295 a = torch.randn(2, 3, device=device)
1307 r = torch.randn((), device=device)
1314 def test_setitem_scalars(self, device): argument
1318 a = torch.randn(2, 3, device=device)
1321 b = torch.randn(3, device=device)
1330 r = torch.randn((), device=device)
1338 def test_basic_advanced_combined(self, device): argument
1340 x = torch.arange(0, 12, device=device).view(4, 3)
1354 def test_int_assignment(self, device): argument
1355 x = torch.arange(0, 4, device=device).view(2, 2)
1359 x = torch.arange(0, 4, device=device).view(2, 2)
1360 x[1] = torch.arange(5, 7, device=device)
1363 def test_byte_tensor_assignment(self, device): argument
1364 x = torch.arange(0.0, 16, device=device).view(4, 4)
1365 b = torch.ByteTensor([True, False, True, False]).to(device)
1366 value = torch.tensor([3.0, 4.0, 5.0, 6.0], device=device)
1373 self.assertEqual(x[1], torch.arange(4.0, 8, device=device))
1375 self.assertEqual(x[3], torch.arange(12.0, 16, device=device))
1377 def test_variable_slicing(self, device): argument
1378 x = torch.arange(0, 16, device=device).view(4, 4)
1379 indices = torch.IntTensor([0, 1]).to(device)
1383 def test_ellipsis_tensor(self, device): argument
1384 x = torch.arange(0, 9, device=device).view(3, 3)
1385 idx = torch.tensor([0, 2], device=device)
1389 def test_unravel_index_errors(self, device): argument
1391 torch.unravel_index(torch.tensor(0.5, device=device), (2, 2))
1394 torch.unravel_index(torch.tensor([], device=device), (10, 3, 5))
1400 torch.tensor([1], device=device, dtype=torch.int64),
1408 torch.tensor([1], device=device, dtype=torch.int64), (1, 2, 2.0)
1414 torch.unravel_index(torch.tensor(0, device=device), (2, -3))
1416 def test_invalid_index(self, device): argument
1417 x = torch.arange(0, 16, device=device).view(4, 4)
1420 def test_out_of_bound_index(self, device): argument
1421 x = torch.arange(0, 100, device=device).view(2, 5, 10)
1443 def test_zero_dim_index(self, device): argument
1444 x = torch.tensor(10, device=device)
1454 def test_invalid_device(self, device): argument
1456 b = torch.zeros(5, device=device)
1457 c = torch.tensor([1.0, 2.0], device="cpu")
1466 def test_cpu_indices(self, device): argument
1468 b = torch.zeros(2, device=device)
1469 x = torch.ones(10, device=device)
1471 ref = torch.ones(10, device=device)
1475 self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
1478 def test_take_along_dim(self, device, dtype): argument
1489 shape, device=device, dtype=dtype, noncontiguous=noncontiguous
1500 t = torch.ones((3, 4, 1), device=device)
1501 indices = torch.ones((1, 2, 5), dtype=torch.long, device=device)
1506 t = torch.ones((3, 4, 5), device=device)
1507 indices = torch.ones((3, 0, 5), dtype=torch.long, device=device)
1512 def test_take_along_dim_invalid(self, device, dtype): argument
1515 t = make_tensor(shape, device=device, dtype=dtype)
1543 def test_gather_take_along_dim_cross_device(self, device, dtype): argument
1546 t = make_tensor(shape, device=device, dtype=dtype)
1550 RuntimeError, "Expected all tensors to be on the same device"
1561 RuntimeError, "Expected all tensors to be on the same device"
1572 def test_cuda_broadcast_index_use_deterministic_algorithms(self, device): argument
1579 tensor_b = tensor_a.to(device=device)
1589 tensor_b = tensor_a.to(device=device)
1601 tensor_b = tensor_a.to(device=device)
1607 tensor_b = tensor_a.to(device=device)
1612 def test_index_limits(self, device): argument
1614 t = torch.tensor([], device=device)
1657 def test_index_no_floats(self, device): argument
1658 a = torch.tensor([[[5.0]]], device=device)
1683 def test_none_index(self, device): argument
1685 a = tensor([1, 2, 3], device=device)
1688 def test_empty_tuple_index(self, device): argument
1690 a = tensor([1, 2, 3], device=device)
1694 def test_empty_fancy_index(self, device): argument
1696 a = tensor([1, 2, 3], device=device)
1697 self.assertEqual(a[[]], torch.tensor([], dtype=torch.long, device=device))
1699 b = tensor([], device=device).long()
1700 self.assertEqual(a[[]], torch.tensor([], dtype=torch.long, device=device))
1702 b = tensor([], device=device).float()
1705 def test_ellipsis_index(self, device): argument
1706 a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
1720 self.assertEqual(a[0, ..., 1], torch.tensor(2, device=device))
1727 def test_single_int_index(self, device): argument
1729 a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
1739 def test_single_bool_index(self, device): argument
1741 a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
1746 def test_boolean_shape_mismatch(self, device): argument
1747 arr = torch.ones((5, 4, 3), device=device)
1749 index = tensor([True], device=device)
1752 index = tensor([False] * 6, device=device)
1755 index = torch.ByteTensor(4, 4).to(device).zero_()
1759 def test_boolean_indexing_onedim(self, device): argument
1762 a = tensor([[0.0, 0.0, 0.0]], device=device)
1763 b = tensor([True], device=device)
1767 self.assertEqual(a, tensor([[1.0, 1.0, 1.0]], device=device))
1771 def test_boolean_assignment_value_mismatch(self, device): argument
1774 a = torch.arange(0, 4, device=device)
1777 a[a > -1] = tensor(v).to(device)
1783 def test_boolean_indexing_twodim(self, device): argument
1786 a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
1789 device=device,
1791 self.assertEqual(a[b], tensor([1, 3, 5, 7, 9], device=device))
1792 self.assertEqual(a[b[1]], tensor([[4, 5, 6]], device=device))
1797 self.assertEqual(a, tensor([[0, 2, 0], [4, 0, 6], [0, 8, 0]], device=device))
1799 def test_boolean_indexing_weirdness(self, device): argument
1801 a = torch.ones((2, 3, 4), device=device)
1804 torch.ones(1, 2, device=device), a[True, [0, 1], True, True, [1], [[2]]]
1808 def test_boolean_indexing_weirdness_tensors(self, device): argument
1810 false = torch.tensor(False, device=device)
1811 true = torch.tensor(True, device=device)
1812 a = torch.ones((2, 3, 4), device=device)
1815 torch.ones(1, 2, device=device), a[true, [0, 1], true, true, [1], [[2]]]
1819 def test_boolean_indexing_alldims(self, device): argument
1820 true = torch.tensor(True, device=device)
1821 a = torch.ones((2, 3), device=device)
1825 def test_boolean_list_indexing(self, device): argument
1828 a = tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], device=device)
1831 self.assertEqual(a[b], tensor([[1, 2, 3]], device=device))
1832 self.assertEqual(a[b, b], tensor([1], device=device))
1833 self.assertEqual(a[c], tensor([[1, 2, 3], [4, 5, 6]], device=device))
1834 self.assertEqual(a[c, c], tensor([1, 5], device=device))
1836 def test_everything_returns_views(self, device): argument
1838 a = tensor([5], device=device)
1844 def test_broaderrors_indexing(self, device): argument
1845 a = torch.zeros(5, 5, device=device)
1853 def test_trivial_fancy_out_of_bounds(self, device): argument
1854 a = torch.zeros(5, device=device)
1855 ind = torch.ones(20, dtype=torch.int64, device=device)
1861 ind = torch.ones(20, dtype=torch.int64, device=device)
1866 def test_index_is_larger(self, device): argument
1868 a = torch.zeros((5, 5), device=device)
1869 a[[[0], [1], [2]], [0, 1, 2]] = tensor([2.0, 3.0, 4.0], device=device)
1871 self.assertTrue((a[:3, :3] == tensor([2.0, 3.0, 4.0], device=device)).all())
1873 def test_broadcast_subspace(self, device): argument
1874 a = torch.zeros((100, 100), device=device)
1875 v = torch.arange(0.0, 100, device=device)[:, None]
1876 b = torch.arange(99, -1, -1, device=device).long()
1881 def test_truncate_leading_1s(self, device): argument