Lines Matching full:mps

1 # Owner(s): ["module: mps"]
29 import torch.backends.mps
171 …ss for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS.
190 …ss for `sort` both are used (values and indices), thus resulting in a issmatch between CPU and MPS.
196 # Failures due to lack of implementation of downstream functions on MPS backend
200 # Exception: Caused by sample input at index 3 on MPS
225 if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
230 …if key in MACOS_BEFORE_13_3_XFAILLIST_GRAD and (torch.backends.mps.is_macos13_or_newer() and produ…
601 # inconsistency errors between cpu and mps, max seen atol is 2
616 # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0')
635 # inconsistency errors between cpu and mps, max seen atol is 2
641 # on both cpu and mps there are test cases that might produce inf result
647 # - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0')
663 # Failures due to lack of op implementation on MPS backend
830 # MPS: input sizes must be divisible by output sizes
840 # Convolution for integral types is not supported on MPS
855 # GEMM on MPS is not supported for integral types
874 # new_zeros/new_ones: Cannot convert a MPS Tensor to float64 dtype as
875 # the MPS framework doesn't support float64
969 # mps vs cpu:
979 # float output for float16 input on MPS
984 # Failures due to lack of implementation of downstream functions on MPS backend
997 # CPU: empty is returning all 0's and there is a mismatch with MPS
1039 …if key in MACOS_BEFORE_13_3_XFAILLIST and (torch.backends.mps.is_macos13_or_newer() and product_ve…
1044 if key in MACOS_AFTER_13_1_XFAILLIST and torch.backends.mps.is_macos13_or_newer(2):
1054 if key in MACOS_12_3_XFAILLIST and (not torch.backends.mps.is_macos13_or_newer()):
1091 # MPS does not support tensor dimensions > 16
1116 if not torch.backends.mps.is_available():
1117 print('MPS not available, skipping tests', file=sys.stderr)
1124 # Determine whether to enable MPS memory leak check (uses same code as CUDA).
1141 caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
1144 torch.mps.empty_cache()
1147 self.caching_allocator_before = torch.mps.current_allocated_memory()
1148 self.driver_before = torch.mps.driver_allocated_memory()
1157 caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
1167 torch.mps.empty_cache()
1172 caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
1173 driver_mem_allocated = torch.mps.driver_allocated_memory()
1191 msg = ("MPS caching allocator reports a memory leak not "
1195 … f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")
1199 msg = (f"MPS driver API confirmed a leak in {self.name}! "
1202 … f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")
1209 autocast_tensor_A = torch.rand((8, 8), device="mps")
1210 autocast_tensor_B = torch.rand((8, 8), device="mps")
1216 with torch.autocast(device_type="mps"):
1229 # Expand TestCase class with Memory Leak Detection on MPS device
1237 # Wraps the tested method if we should do MPS memory check.
1266 l.append(torch.randn(1024 * 1024 * 8, device=torch.device("mps")))
1272 with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
1279 x = x.to(device='mps', dtype=torch.float16)
1281 a = torch.randn(128, 128, device='mps', dtype=torch.float16)
1282 # Warm up / prebuild MPS shaders (otherwise check fails on 13.2)
1284 torch.mps.empty_cache()
1285 driver_before = torch.mps.driver_allocated_memory()
1287 torch.mps.empty_cache()
1288 driver_after = torch.mps.driver_allocated_memory()
1304 input = torch.rand(channels, requires_grad=True, device='mps')
1307 input = torch.rand(width, height, requires_grad=True, device='mps').T
1312 … input = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
1364 input = torch.rand(channels, requires_grad=True, device='mps')
1366 input = torch.rand(height, width, requires_grad=True, device='mps')
1369 … input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')
1466 device="mps")
1469 device="mps")
1470 self._testRelu(np.array([]).astype(t), device="mps")
1471 self._testReluInPlace(np.array([]).astype(t), device="mps")
1476 tensor1_mps = torch.randn(shape_tensor_1, device="mps").expand(expand_tensor_1_shape)
1478 tensor1_mps = torch.randn(shape_tensor_1, device="mps")
1481 tensor2_mps = torch.randn(shape_tensor_2, device="mps").expand(expand_tensor_2_shape)
1483 tensor2_mps = torch.randn(shape_tensor_2, device="mps")
1527 mps_x = cpu_x.detach().clone().to('mps')
1547 mps_grad = cpu_grad.to('mps')
1609 x.to('mps'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
1615 def test_exp(self, device="mps", dtype=torch.float):
1618 a = torch.tensor(v, dtype=dtype, device="mps") * b
1621 def test_conv_raises_error(self, device='mps', dtype=torch.float):
1622 conv = nn.Conv1d(1, 65537, 3, padding=1).to('mps')
1626 y = conv(x.to("mps"))
1628 def test_triu_inf(self, device="mps", dtype=torch.float):
1631 mask_mps = mask.clone().detach().to('mps')
1636 def test_exp1(self, device="mps", dtype=torch.float):
1640 # If exponentWithTensor: MPS call is used on M1 running 14.5 test will fail with
1647 x = torch.rand((256, 10), device='mps')
1659 mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
1668 mps_grad = cpu_grad.to('mps')
1678 device="mps")
1683 tensor = torch.zeros(shape, device='mps', dtype=dtype)
1698 tensor = torch.ones(shape, device="mps")
1708 tensor = torch.ones(shape, device="mps")
1709 val_tensor_mps = torch.tensor(val, device="mps")
1721 def test_cdist_large(self, device="mps"):
1729 def test_cdist_large_batch(self, device="mps"):
1737 def test_cdist_non_contiguous(self, device="mps"):
1763 def test_cdist_non_contiguous_batch(self, device="mps"):
1789 def test_cdist_euclidean_large(self, device="mps"):
1808 def test_cdist_same_inputs(self, device="mps"):
1832 def test_cdist_norm(self, device="mps"):
1849 def test_cdist_norm_batch(self, device="mps"):
1867 B = torch.ones(5, 6).to("mps")
1868 C = torch.ones(6, 5).to("mps")
1874 device = "mps"
1914 a = torch.randn(4, 3, device="mps")
1915 b = torch.randn(4, 3, device="mps")
1923 A = torch.ones(5, 5).to("mps")
1924 B = torch.ones(5, 6).to("mps")
1925 C = torch.ones(6, 5).to("mps")
1933 batch1_mps = batch1_cpu.detach().clone().to("mps")
1934 batch2_mps = batch2_cpu.detach().clone().to("mps")
1945 batch1 = torch.randn(11, 20064, 128, dtype=dtype, device='mps')
1946 batch2 = torch.randn(11, 128, 20064, dtype=dtype, device='mps')
1957 A = torch.ones(5, 10).to("mps")
1958 B = torch.ones(5).to("mps")
1959 C = torch.ones(10).to("mps")
1965 M_mps = M_cpu.detach().clone().to("mps")
1978 M_mps = M_cpu.detach().clone().to("mps")
1979 batch1_mps = batch1_cpu.detach().clone().to("mps")
1980 batch2_mps = batch2_cpu.detach().clone().to("mps")
1996 M_mps = M_cpu.detach().clone().to("mps")
1997 batch1_mps = batch1_cpu.detach().clone().to("mps")
1998 batch2_mps = batch2_cpu.detach().clone().to("mps")
2012 y_mps = x_cpu.to("mps")
2019 x_mps = x.to('mps')
2020 projected_mps = projected.to('mps')
2028 x_mps = x.to('mps')
2029 projected_mps = projected.to('mps')
2042 device = "mps"
2052 # Mixed CPU<->MPS tensors
2057 torch.nn.functional.linear(torch.rand(size, device='mps'),
2058 torch.randint(-10, 10, size, dtype=torch.int8, device='mps'))
2061 with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"):
2062 torch.nn.functional.linear(torch.rand(size, device='mps'),
2066 with self.assertRaisesRegex(RuntimeError, "argument input is on cpu but expected on mps"):
2068 torch.rand(size, device='mps'))
2072 …near = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)
2075 mps_linear.weight.data = cpu_linear.weight.data.detach().clone().to("mps")
2078 mps_linear.bias.data = cpu_linear.bias.data.detach().clone().to("mps")
2080 linear_mps_input = torch.randn(shape).to('mps')
2095 grad = cpu_grad.detach().to('mps').requires_grad_()
2111 x_grad_out_mps = x_grad_out.to("mps")
2113 w_grad_out_mps = w_grad_out.to("mps")
2121 b_grad_out_mps = b_grad_out.to("mps")
2201 def test_randperm(self, device="mps"):
2244 x = cpu_x.detach().clone().to('mps').requires_grad_()
2254 grad = cpu_grad.to('mps')
2266 grad = cpu_grad.to('mps')
2312 x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True)
2338 x_mps = torch.randn(shape, device="mps")
2341 mask_mps = torch.rand(shape, device="mps") < 0.6
2344 y_mps = torch.randn(shape, device="mps")
2358 device = "mps"
2378 x_mps = torch.randn(shape, device="mps")
2380 mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool)
2401 x = cpu_x.detach().clone().to('mps').requires_grad_()
2420 x = cpu_x.detach().clone().to('mps').requires_grad_()
2432 running_mean = cpu_running_mean.detach().clone().to('mps')
2433 running_var = cpu_running_var.detach().clone().to('mps')
2441 weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2443 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2477 device='mps')
2490 device='mps')
2503 device='mps')
2528 grad = cpu_grad.to('mps')
2571 inputs = torch.rand(1, 8, 4, 4, device="mps", requires_grad=True)
2572 x = torch.nn.BatchNorm2d(8).to("mps")
2573 y = torch.nn.BatchNorm2d(8).to("mps")
2583 bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')
2586 x_mps = x_cpu.to('mps')
2594 inputs = torch.rand(4, 4, device="mps", requires_grad=True)
2595 x = torch.nn.LayerNorm(4).to("mps")
2596 y = torch.nn.LayerNorm(4).to("mps")
2604 a = torch.arange(9, dtype=torch.float, device="mps") - 4
2626 c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps")
2641 d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
2653 x_mps = torch.tensor([0, 0, 0, 2, 3], dtype=torch.float, device="mps")
2660 a_mps = torch.arange(27, dtype=torch.float, device="mps") - 4
2684 x = cpu_x.detach().clone().to('mps').requires_grad_()
2687 …erNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='mps', dtype=dtype)
2689 wt = cpu_wt.detach().clone().to('mps').requires_grad_()
2691 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2703 grad = cpu_grad.to('mps')
2720 …torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(torch.randn(1, 2, 16).to("mps", dtype…
2725 device = torch.device("mps")
2738 freq_mps = torch.fft.fftfreq(10**4, device='mps')
2751 x = cpu_x.detach().clone().to('mps').requires_grad_()
2763 running_mean = cpu_running_mean.detach().clone().to('mps')
2764 running_var = cpu_running_var.detach().clone().to('mps')
2772 weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2774 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2806 device='mps')
2819 device='mps')
2832 device='mps')
2857 grad = cpu_grad.to('mps')
2910 grad = cpu_grad.to('mps')
2923 x = cpu_x.detach().clone().to('mps').requires_grad_()
2926 weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2929 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2932 linear = torch.nn.Linear(5, 10, device='mps')
2944 x = cpu_x.detach().clone().to('mps').requires_grad_()
2947 conv = torch.nn.Conv2d(3, 3, 3, device='mps')
2958 x = cpu_x.detach().clone().to('mps').requires_grad_()
2961 conv = torch.nn.Conv3d(3, 3, 3, device='mps')
2996 x = cpu_x.detach().clone().to('mps').requires_grad_()
2999 wt = cpu_wt.detach().clone().to('mps').requires_grad_()
3006 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
3014 grad = cpu_grad.to('mps')
3070 x = cpu_x.detach().clone().to('mps').requires_grad_()
3073 wt = cpu_wt.detach().clone().to('mps').requires_grad_()
3080 bias = cpu_bias.detach().clone().to('mps').requires_grad_()
3089 grad = cpu_grad.to('mps')
3134 x = cpu_x.detach().clone().to('mps').requires_grad_()
3142 grad = cpu_grad.to('mps')
3159 x = cpu_x.detach().clone().to('mps').requires_grad_()
3167 grad = cpu_grad.to('mps')
3184 input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)
3190 mps_grad = cpu_grad.to('mps')
3210 x = cpu_x.detach().clone().to('mps')
3212 y = cpu_y.detach().clone().to('mps')
3220 x = cpu_x.detach().clone().to('mps')
3230 y = cpu_y.detach().clone().to('mps')
3247 x = cpu_x.detach().clone().to('mps')
3250 y = cpu_y.detach().clone().to('mps')
3253 z = cpu_z.detach().clone().to('mps')
3284 mps_x = cpu_x.detach().clone().to('mps')
3285 mps_y = cpu_y.detach().clone().to('mps')
3286 mps_z = cpu_z.detach().clone().to('mps')
3287 mps_out = cpu_out.detach().clone().to('mps')
3312 x_mps = x.detach().clone().to(device="mps")
3313 y_mps = y.detach().clone().to(device="mps")
3314 z_mps = z.detach().clone().to(device="mps")
3318 result_mps_out = result_cpu.detach().clone().to('mps')
3335 mps_A = cpu_A.to('mps')
3336 mps_F = cpu_F.to('mps')
3342 mps_x = torch.tensor(values, device='mps')
3351 x = torch.tensor(1).expand([10]).to("mps")
3359 a1 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
3363 a2 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
3364 b2 = torch.Tensor([-1, -1]).to(torch.device("mps"))
3370 x = torch.randn([1, 4, 4], device="mps")
3381 x = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps")
3398 lin_mps = nn.Linear(10, 256, device="mps")
3401 lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
3402 lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()
3404 x_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
3412 cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
3433 x_mps = x_cpu.detach().clone().to("mps")
3440 t_mps = torch.tensor([1, 2, 3, 4], device="mps")
3459 a_mps = a.to("mps")
3467 mps_in = cpu_in.detach().clone().to("mps")
3476 x_mps = torch.randn(1, 4800, 2, device="mps")
3488 t_mps = torch.randn(shape, device="mps")
3542 x = torch.randn(2, 3, 3, device="mps")
3560 idx_mps = idx.to("mps")
3561 pts_mps = pts.to("mps")
3566 actual_pts_mps = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float, device="mps")
3576 tensor = torch.randint(10, shape, device="mps")
3578 torch.empty(shape[0], shape[1] * 2, device="mps")[:, ::2].copy_(tensor)
3584 mps_x = (torch.tensor(values, device='mps', dtype=torch.float))
3606 x_mps = x_tensor.detach().clone().to("mps")
3624 # Tests for previously encountered MPS bugs
3661 x_mps = x.detach().clone().to("mps")
3674 tensor_list = torch.tensor([1.0, 1.2], device="mps")
3684 tensor_list = torch.tensor([1.0, 1.2, 2.5, 1.0], device="mps")
3699 t_mps = torch.tensor([1, 2, 3, 4], device="mps")
3739 x_mps = torch.tensor([0.5, 0.5], device="mps")
3758 …# when copying from 'cpu' to 'mps', c will have a storage_offset of 1 which needs to be taking int…
3760 b = b_cpu.to('mps')
3761 c = c_cpu.to('mps')
3774 mps_x = torch.tensor(values, device='mps')
3793 x = cpu_x.detach().clone().to('mps').requires_grad_()
3799 grad = cpu_grad.to('mps')
3812 def test_torch_repeat_interleave(self, device="mps"):
3819 lengths = torch.tensor([1, 2], dtype=dtype, device="mps")
3838 def test_repeat_interleave(self, device="mps"):
3891 x = torch.randn(shape, dtype=dtype, device="mps")
3900 helper(shape=3, num_repeats=torch.tensor([100], device="mps"))
3901 helper(shape=(2, 2), num_repeats=torch.tensor([3, 3], device="mps"), dim=0)
3902 helper(shape=(10, 15, 8), num_repeats=torch.arange(10, device="mps"), dim=0)
3903 helper(shape=(10, 15, 8), num_repeats=torch.randint(0, 100, (15, ), device="mps"), dim=1)
3904 helper(shape=(10, 15, 30), num_repeats=torch.randint(0, 100, (30, ), device="mps"), dim=2)
3913 mps_x = torch.tensor(n, dtype=dtype).to('mps')
3953 mps_x = torch.tensor(values, device='mps')
3959 self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(),
3961 self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False))
3962 self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True))
3963 self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(),
3965 self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(),
3967 self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int),
3969 self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(),
3971 self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
3972 torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))
3978 x_mps = x_cpu.to('mps')
3983 device = 'mps'
3998 x = torch.rand(32, 1, device='mps')
4009 a_mps = a_cpu.to(torch.device('mps'))
4027 a_mps = a_cpu.to(torch.device('mps'))
4055 t_mps = t.to("mps")
4062 x = torch.full((3, 3), True, device='mps')
4064 y_mps = torch.full((2, 2), 247, device='mps', dtype=torch.uint8)
4073 x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
4079 x = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')
4080 y = torch.tensor([0, 1], dtype=torch.bool, device='mps')
4084 x = torch.tensor([[1], [0]], dtype=torch.int8, device='mps')
4085 y = torch.tensor([0, 1], dtype=torch.int8, device='mps')
4090 x = torch.tensor([[]], device='mps')
4140 def test_unique_all_dtypes(self, device="mps"):
4200 x = cpu_x.detach().clone().to('mps')
4214 x = torch.arange(2, device="mps")
4220 x = cpu_x.detach().clone().to('mps')
4260 mps_data = data.to("mps")
4273 t_mps = torch.tensor(a, device="mps")
4280 y = x.to('mps')
4285 y = x.to('mps')
4289 y = torch.full((4, 4, 4, 4), 13, device="mps")
4292 # As y is on MPS and z on CPU, this dispatches to a copy operator
4299 x_mps = torch.zeros(5, device="mps", dtype=torch.float32)
4301 update_mps = torch.tensor([1, 1], device="mps", dtype=torch.int64)
4314 mps_src = cpu_src.to("mps")
4315 mps_dst = cpu_dst.to("mps")
4333 self.assertEqual(torch.tensor([[1]], device='mps').item(), 1.0)
4339 # Example values for all dtypes supported on the MPS backend
4357 # print(getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
4358 # (torch.tensor(val2, dtype=dtype2, device='mps')))
4362 getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
4363 (torch.tensor(val2, dtype=dtype2, device='mps')),
4367 getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
4368 (torch.tensor([val2], dtype=dtype2, device='mps')),
4372 getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
4373 (torch.tensor([val2], dtype=dtype2, device='mps')),
4377 getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
4378 (torch.tensor(val2, dtype=dtype2, device='mps')),
4382 x1 = torch.full(full_shape, val1, dtype=dtype1, device='mps')
4383 y1 = torch.tensor(val2, dtype=dtype2, device='mps')
4387 x3 = torch.tensor(val1, dtype=dtype1, device='mps')
4388 y3 = torch.full(full_shape, val2, dtype=dtype2, device='mps')
4393 getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
4394 (torch.full(full_shape, val2, dtype=dtype2, device='mps')),
4415 x_mps = x_cpu.to('mps')
4416 actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
4425 # Test MPS
4445 t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
4458 self.assertEqual(e_string, "MPS does not support cumsum_out_mps op with int64 input." +
4464 t_mps = a.to("mps").cumsum(0)
4476 x = cpu_x.detach().clone().to('mps')
4487 t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
4500 self.assertEqual(e_string, "MPS does not support cumprod_out_mps op with int64 input."
4511 x = cpu_x.detach().clone().to('mps')
4523 x = cpu_x.detach().clone().to('mps')
4535 a = torch.tensor(1., device="mps", requires_grad=True)
4547 ones1 = torch.tensor(values_1, device='mps')
4548 x = cpu_x.detach().clone().to('mps').requires_grad_()
4558 mps_x = cpu_x.to('mps')
4569 x_mps = torch.arange(1., 8, device="mps")
4579 x = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
4583 x = torch.tensor(0.5, device="mps")
4587 self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1))
4588 self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2))
4589 self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1))
4592 input = torch.randint(0, 8, (5,), dtype=torch.int32, device="mps")
4594 weights = torch.linspace(0, 1, steps=5, device="mps", dtype=torch.float32)
4606 device = "mps"
4696 device = "mps"
4722 x = cpu_x.detach().clone().to('mps').requires_grad_()
4742 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
4743 targetMPS = targetCPU.detach().clone().to('mps')
4771 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
4772 targetMPS = targetCPU.detach().clone().to('mps')
4798 model_mps = copy.deepcopy(model_cpu).to("mps")
4803 x_mps = x.detach().clone().to("mps").permute(0, 2, 1)
4813 y_hat_mps = y_hat.detach().clone().to("mps")
4830 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
4831 targetMPS = targetCPU.detach().clone().to('mps')
4853 target = torch.ones(5, device='mps')
4854 input = torch.ones(5, device='mps')
4857 target = torch.zeros(5, device='mps')
4858 input = torch.zeros(5, device='mps')
4863 a = torch.rand(25, device='mps')
4864 b = torch.rand(25, 1, device='mps')
4871 target = torch.rand(x_size, y_size, device='mps')
4874 output_sig = torch.rand(x_size, y_size, device='mps') - 0.5
4879 weight = torch.rand(y_size, device='mps')
4891 grad = torch.rand(x_size, y_size, device='mps')
4901 output = torch.zeros(3, 1, requires_grad=True, device='mps')
4902 target = torch.zeros(3, 1, device='mps')
4904 expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
4908 target = torch.rand(16, 4, device='mps')
4909 output = torch.rand(16, 4, device='mps') - 0.5
4911 weight = torch.rand(4, device='mps')
4919 weight = torch.rand(16, 1, device='mps')
4928 target = torch.rand(64, 4, device='mps')
4929 output = torch.rand(64, 4, device='mps') - 0.5
4930 pos_weight = torch.ones(64, 4, device='mps')
4936 target = torch.rand(64, 4, device='mps')
4937 output = torch.rand(64, 4, device='mps') - 0.5
4938 pos_weight = torch.rand(4, device='mps')
4951 output = torch.zeros(3, 1, requires_grad=True, device='mps')
4952 target = torch.zeros(3, 1, device='mps')
4953 pos_weight = torch.ones(3, 1, device='mps')
4955 expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
4960 output = torch.tensor([0., -120.], device='mps')
4961 target = torch.tensor([0., 1.], device='mps')
4962 pos_weight = torch.tensor([1., 1.], device='mps')
4972 target = torch.rand(16, 4, device='mps')
4973 output = torch.rand(16, 4, device='mps') - 0.5
4975 weight = torch.rand(4, device='mps')
4983 weight = torch.rand(16, 1, device='mps')
4994 pred = torch.randn(3, 5, requires_grad=True, dtype=torch.float16, device='mps')
4995 target = torch.ones(3, dtype=torch.long, device='mps')
5002 mps_x = torch.tensor(values, device='mps', requires_grad=True)
5009 mps_grad = torch.ones_like(cpu_log_softmax).to('mps')
5022 mps_x = torch.tensor(values, device='mps', requires_grad=True)
5029 mps_grad = torch.ones_like(cpu_log_softmax).to('mps')
5039 mps_x = torch.tensor(values1, device='mps')
5040 mps_y = torch.tensor(values2, device='mps')
5051 mps_x = torch.tensor((-1, 2, 3), device='mps', dtype=torch.uint8)
5060 mps_x = torch.tensor(values1, device='mps')
5061 mps_y = torch.tensor(values2, device='mps')
5073 mps_x = cpu_x.detach().clone().to('mps')
5074 mps_y = cpu_y.detach().clone().to('mps')
5085 mps_x = cpu_x.detach().clone().to('mps')
5097 mps_x = cpu_x.detach().clone().to('mps')
5098 mps_y = cpu_y.detach().clone().to('mps')
5109 mps_x = cpu_x.detach().clone().to('mps')
5121 mps_x = cpu_x.detach().clone().to('mps')
5122 mps_y = cpu_y.detach().clone().to('mps')
5133 mps_x = cpu_x.detach().clone().to('mps')
5145 mps_x = cpu_x.detach().clone().to('mps')
5146 mps_y = cpu_y.detach().clone().to('mps')
5157 mps_x = cpu_x.detach().clone().to('mps')
5169 mps_x = cpu_x.detach().clone().to('mps')
5170 mps_y = cpu_y.detach().clone().to('mps')
5181 mps_x = cpu_x.detach().clone().to('mps')
5194 mps_tensor = cpu_tensor.to(torch.device('mps'))
5199 mps_tensor = torch.randn(10, 2, device='mps', dtype=torch.float32)
5218 x = cpu_x.detach().clone().to('mps')
5221 x = cpu_x.detach().clone().to('mps')
5224 x = cpu_x.detach().clone().to('mps').requires_grad_()
5273 …ps = torch.tensor([sys.maxsize, sys.maxsize - 10, sys.maxsize - 5, sys.maxsize - 18], device="mps")
5287 x = cpu_x.detach().clone().to('mps')
5290 x = cpu_x.detach().clone().to('mps')
5293 x = cpu_x.detach().clone().to('mps')
5306 y_0 = torch.ones(c, h, w, device='mps', dtype=dtype)
5307 idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
5313 y_0dim = torch.ones(1, c, h, w, device='mps', dtype=dtype)
5314 idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
5320 y_1 = torch.ones(n, h, w, device='mps', dtype=dtype)
5321 idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
5327 y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=dtype)
5328 idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
5334 y_2 = torch.ones(n, c, w, device='mps', dtype=dtype)
5335 idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
5341 y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=dtype)
5342 idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
5348 y_3 = torch.ones(n, c, h, device='mps', dtype=dtype)
5349 idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
5355 y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=dtype)
5356 idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
5369 mps_x = cpu_x.detach().clone().to('mps')
5385 mps_x = cpu_x.detach().clone().to('mps')
5426 x = cpu_x.detach().clone().to('mps')
5476 x_mps = fn(torch.zeros(shape, device="mps"), dim=dim)
5484 x = (torch.rand(2, 3, 4, 3, 4, 2, device="mps") - .5).relu()
5508 x = cpu_x.detach().clone().to('mps')
5555 x_mps = x_cpu.to("mps")
5562 x = cpu_x.detach().clone().to('mps')
5573 y_0 = torch.ones(c, h, w, device='mps', dtype=torch.float)
5574 idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
5585 y_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.float)
5586 idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
5597 y_1 = torch.ones(n, h, w, device='mps', dtype=torch.float)
5598 idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
5609 y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.float)
5610 idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
5621 y_2 = torch.ones(n, c, w, device='mps', dtype=torch.float)
5622 idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
5633 y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.float)
5634 idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
5645 y_3 = torch.ones(n, c, h, device='mps', dtype=torch.float)
5646 idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
5657 y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.float)
5658 idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
5673 x = cpu_x.detach().clone().to('mps')
5676 x = cpu_x.detach().clone().to('mps')
5679 x = cpu_x.detach().clone().to('mps').requires_grad_()
5731 x = torch.ones(2, 4, 1, 30, 1, device='mps').sum(dim=-2)
5742 x = cpu_x.detach().clone().to('mps')
5745 x = cpu_x.detach().clone().to('mps')
5748 x = cpu_x.detach().clone().to('mps').requires_grad_()
5773 x = cpu_x.detach().clone().to('mps').requires_grad_()
5826 x = cpu_x.detach().clone().to('mps')
5929 x = cpu_x.detach().clone().to('mps')
5970 x = cpu_x.detach().clone().to('mps').requires_grad_()
5976 grad = cpu_grad.to('mps')
5992 x = cpu_x.detach().clone().to('mps').requires_grad_()
5998 grad = cpu_grad.to('mps')
6015 mps_x = cpu_x.detach().clone().to('mps')
6016 mps_y = cpu_y.detach().clone().to('mps')
6030 x = cpu_x.detach().clone().to('mps')
6034 clamp_min_vals_mps = torch.ones(10, device="mps").to(torch.float16)
6035 clamp_max_vals_mps = torch.ones(10, device="mps").to(torch.float16) * 10
6045 t_mps = torch.tensor([torch.nan, 1, 2], device="mps")
6067 x = cpu_x.detach().clone().to('mps')
6070 min_t = cpu_min_t.detach().clone().to('mps')
6089 x = cpu_x.detach().clone().to('mps')
6092 max_t = cpu_max_t.detach().clone().to('mps')
6116 x = cpu_x.detach().clone().to('mps')
6121 min_t = cpu_min_t.detach().clone().to('mps')
6126 max_t = cpu_max_t.detach().clone().to('mps')
6191 mps_x = cpu_x.detach().clone().to('mps')
6193 mps_y = cpu_y.detach().clone().to('mps')
6212 mps_x = cpu_x.detach().clone().to('mps')
6237 …h.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device="mps"), torch.tensor(2, device="mps", dt…
6243 torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="mps"), -1.5)
6250 x = cpu_x.detach().clone().to('mps')
6263 x = x_cpu.detach().clone().to('mps')
6269 x = cpu_x.detach().clone().to('mps').requires_grad_()
6288 device = 'mps'
6311 inputMPS = inputCPU.detach().to('mps').requires_grad_()
6344 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
6376 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
6420 x = cpu_x.detach().clone().to('mps')
6423 y = cpu_y.detach().clone().to('mps')
6426 z = cpu_z.detach().clone().to('mps')
6454 x = cpu_x.detach().clone().to('mps')
6456 y = cpu_y.detach().clone().to('mps')
6458 z = cpu_z.detach().clone().to('mps')
6461 x = cpu_x.detach().clone().to('mps')
6463 y = cpu_y.detach().clone().to('mps')
6465 z = cpu_z.detach().clone().to('mps')
6468 x = cpu_x.detach().clone().to('mps').requires_grad_()
6470 y = cpu_y.detach().clone().to('mps').requires_grad_()
6472 z = cpu_z.detach().clone().to('mps').requires_grad_()
6491 x = cpu_x.detach().clone().to('mps')
6503 x = cpu_x.detach().clone().to('mps')
6515 x = cpu_x.detach().clone().to('mps')
6527 x = cpu_x.detach().clone().to('mps')
6539 x = cpu_x.detach().clone().to('mps')
6551 x = cpu_x.detach().clone().to('mps')
6554 y = cpu_y.detach().clone().to('mps')
6566 x = cpu_x.detach().clone().to('mps')
6569 y = cpu_y.detach().clone().to('mps')
6581 x = cpu_x.detach().clone().to('mps')
6595 x = cpu_x.detach().clone().to('mps')
6598 y = cpu_y.detach().clone().to('mps')
6601 z = cpu_z.detach().clone().to('mps')
6604 w = cpu_w.detach().clone().to('mps')
6613 x = cpu_x.detach().clone().to('mps')
6616 y = cpu_y.detach().clone().to('mps')
6619 z = cpu_z.detach().clone().to('mps')
6638 x = cpu_x.detach().clone().to('mps')
6651 x = cpu_x.detach().clone().to('mps').requires_grad_()
6657 grad = cpu_grad.to('mps')
6671 x = cpu_x.detach().clone().to('mps').requires_grad_()
6677 grad = cpu_grad.to('mps')
6693 x = cpu_x.detach().clone().to('mps').requires_grad_(True)
6699 grad = cpu_grad.to('mps')
6722 F.elu(elu_input_noncontiguous.to('mps'), alpha, inplace)
6729 x = cpu_x.detach().clone().to('mps').requires_grad_()
6736 grad = cpu_grad.to('mps')
6752 x = cpu_x.detach().clone().to('mps').requires_grad_()
6758 grad = cpu_grad.to('mps')
6780 x = cpu_x.detach().clone().to('mps').requires_grad_()
6786 grad = cpu_grad.to('mps')
6801 input_cast_mps = input.to('mps')
6812 input_mps = input_cpu.to('mps')
6823 x = cpu_x.detach().clone().to('mps').requires_grad_()
6829 grad = cpu_grad.to('mps')
6842 x = cpu_x.detach().clone().to('mps').requires_grad_()
6848 grad = cpu_grad.to('mps')
6891 x = cpu_x.detach().clone().to('mps').requires_grad_()
6904 grad = cpu_grad.to('mps')
6926 x = cpu_x.detach().clone().to('mps')
6942 grad = cpu_grad.to('mps')
6961 …tRaises(RuntimeError, lambda: torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps")))
6966 x = cpu_x.detach().clone().to('mps')
6981 grad = cpu_grad.to('mps')
7005 devices += ['mps']
7032 x = cpu_x.detach().clone().to('mps')
7048 x = cpu_x.detach().clone().to('mps').requires_grad_()
7051 x = cpu_x.detach().clone().to('mps')
7060 grad = cpu_grad.to('mps')
7076 input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)
7087 mps_grad = cpu_grad.to('mps')
7107 mps_x = torch.tensor(values, device='mps')
7108 mps_x1 = torch.tensor(values1, device='mps')
7117 mps_x = torch.tensor(values, device='mps')
7136 mps_x = torch.tensor(values, device='mps')
7166 x = cpu_x.detach().clone().to('mps').requires_grad_()
7172 grad = cpu_grad.to('mps')
7184 x = cpu_x.clone().to('mps')
7199 x = cpu_x.detach().clone().to('mps').requires_grad_()
7205 grad = cpu_grad.to('mps')
7216 x = torch.arange(18.0, device='mps').reshape(2, 3, 3)
7225 x = cpu_x.detach().clone().to('mps')
7228 idx = cpu_idx.detach().clone().to('mps')
7231 source = cpu_source.detach().clone().to('mps')
7255 torch.mps.empty_cache()
7257 x = torch.rand(16000, 67120, device="mps")
7259 idx = torch.arange(0, 2, device="mps")
7265 torch.mps.empty_cache()
7269 x = torch.rand(10, 1, device="mps")
7270 y = torch.rand(1, 32769, device="mps")
7277 x = torch.rand(m, n, device="mps", dtype=dtype)
7278 y = torch.rand(n, k, device="mps", dtype=dtype)
7283 … # Used to produce incorrect results with MPS on M1 running MacOS 14.3, but correct with Metal
7298 y = x.to(device="mps")
7299 self.assertTrue(torch.all(y == torch.tensor(1.0, device="mps")))
7307 x = cpu_x.detach().clone().to('mps')
7331 x = cpu_x.detach().clone().to('mps')
7334 idx = cpu_idx.detach().clone().to('mps')
7353 x = cpu_x.detach().clone().to('mps')
7356 idx = cpu_idx.detach().clone().to('mps')
7369 embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')
7371 W_MPS = torch.randn((m, d), requires_grad=True, device='mps')
7372 idx_MPS = torch.tensor(idx, device='mps')
7403 x = cpu_x.detach().clone().to('mps').requires_grad_()
7409 idx = cpu_idx.detach().clone().to('mps')
7415 grad = cpu_grad.to('mps')
7436 x = cpu_x.detach().clone().to('mps').requires_grad_()
7441 idx = cpu_idx.detach().clone().to('mps')
7447 grad = cpu_grad.to('mps')
7458 x = cpu_x.detach().clone().to('mps').requires_grad_()
7461 src = cpu_src.detach().clone().to('mps').requires_grad_()
7475 idx = cpu_idx.detach().clone().to('mps')
7492 grad = cpu_grad.to('mps')
7525 x = cpu_x.detach().clone().to('mps').requires_grad_()
7528 src = cpu_src.detach().clone().to('mps').requires_grad_()
7534 idx = cpu_idx.detach().clone().to('mps')
7550 grad = cpu_grad.to('mps')
7565 x = cpu_x.detach().clone().to('mps').requires_grad_()
7568 src = cpu_src.detach().clone().to('mps').requires_grad_()
7574 idx = cpu_idx.detach().clone().to('mps')
7600 self.assertFalse(torch.is_nonzero(torch.tensor([0.]).to('mps')))
7601 self.assertTrue(torch.is_nonzero(torch.tensor([1.5]).to('mps')))
7602 self.assertFalse(torch.is_nonzero(torch.tensor([False]).to('mps')))
7603 self.assertTrue(torch.is_nonzero(torch.tensor([3]).to('mps')))
7609 x = cpu_x.detach().clone().to('mps').requires_grad_()
7615 grad = cpu_grad.to('mps')
7635 mps_input = cpu_input.to('mps')
7650 x = cpu_x.detach().clone().to('mps').requires_grad_()
7656 grad = cpu_grad.to('mps')
7680 result = torch.eye(n, dtype=dtype, device='mps')
7683 result = torch.eye(n, m, device='mps')
7699 x = cpu_x.detach().clone().to('mps').requires_grad_()
7705 # grad = cpu_grad.to('mps')
7721 result = torch.linspace(start, end, steps, dtype=dtype, device='mps')
7732 self.assertEqual(np.arange(10), torch.arange(10, device='mps'))
7733 self.assertEqual(np.arange(7, 1, -1), torch.arange(7, 1, -1, device='mps'))
7734 … self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), torch.arange(1, 2, .3, device='mps'))
7735 self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(6.3, device='mps'))
7738 out_mps = torch.tensor([], device="mps")
7747 self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps'))
7748 self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps'))
7749 …(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps'))
7750 self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(0, 6.3, device='mps'))
7759 x = cpu_x.detach().clone().to('mps').requires_grad_()
7770 grad = cpu_grad.to('mps')
7781 x = cpu_x.detach().clone().to('mps').requires_grad_()
7787 grad = cpu_grad.to('mps')
7806 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
7816 cond = cpu_cond.detach().clone().to('mps')
7819 x = cpu_x.detach().clone().to('mps').requires_grad_()
7822 y = cpu_y.detach().clone().to('mps').requires_grad_()
7828 grad = cpu_grad.to('mps')
7850 # TODO: Remove me when out OpInfo testing is enabled on MPS
7851 output = torch.tensor(0.0, device="mps")
7852 cond = torch.randint(2, (3, 3), dtype=torch.bool, device="mps")
7853 inp = torch.rand(3, 3, device="mps")
7854 other = torch.rand(3, 3, device="mps")
7862 mps_out = torch.normal(mean, std, shape, device='mps')
7867 mean_tensor = cpu_mean_tensor.detach().clone().to('mps')
7872 std_tensor = cpu_std_tensor.detach().clone().to('mps')
7875 mps_out = torch.zeros(shape, device='mps')
7878 mps_out = torch.zeros(shape, device='mps')
7881 mps_out = torch.zeros(shape, device='mps')
7900 all_ones = torch.ones(shape, device='mps')
7901 all_zeros = torch.zeros(shape, device='mps')
7921 mps_out = torch.zeros(shape, device='mps', dtype=dtype).bernoulli(0.5)
7925 self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype))
7931 # explicit manual seeding by creating an MPS Generator
7932 g_mps = torch.Generator(device='mps')
7934 mps_x = torch.randn(5, device='mps', generator=g_mps)
7937 mps_y = torch.randn(5, device='mps', generator=g_mps)
7944 mps_x = torch.randn(5, device='mps', generator=g_mps)
7953 mps_x = torch.randn(5, device='mps', generator=g_mps)
7958 # manual seeding on the "default" MPS generator using
7961 mps_x = torch.randn(5, device='mps')
7962 # manual seeding using torch.mps.manual_seed()
7963 # which should set the "default" MPS generator
7965 torch.mps.manual_seed(230)
7967 mps_y = torch.randn(5, device='mps')
7972 g_state = torch.mps.get_rng_state()
7975 mps_x = torch.randn(5, device='mps')
7979 self.assertEqual(torch.mps._get_default_mps_generator().get_offset(), 2)
7984 …# restore the previously saved state to the "default" MPS generator, and the results should match …
7985 torch.mps.set_rng_state(g_state)
7986 mps_x = torch.randn(5, device='mps')
7991 # MPS stream to finish running each of them
7993 .to(device='mps', dtype=torch.float)
7995 x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
7996 torch.mps.synchronize()
7998 torch.mps.synchronize()
8000 torch.mps.synchronize()
8006 torch.mps.empty_cache()
8008 current_alloc_before = torch.mps.current_allocated_memory()
8013 driver_alloc_before = torch.mps.driver_allocated_memory()
8015 x = torch.ones(1024 * 1024 * 8, device="mps")
8017 current_alloc_after = torch.mps.current_allocated_memory()
8018 driver_alloc_after = torch.mps.driver_allocated_memory()
8025 max_memory = torch.mps.recommended_max_memory()
8035 with torch.mps.profiler.profile(mode="event", wait_until_completed=False) as p:
8038 .to(device='mps', dtype=torch.float)
8039 x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
8042 torch.mps.profiler.start(mode="interval", wait_until_completed=True)
8044 x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
8046 torch.mps.profiler.stop()
8049 startEvent = torch.mps.Event(enable_timing=True)
8052 .to(device='mps', dtype=torch.float)
8053 x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
8055 endEvent = torch.mps.Event(enable_timing=True)
8062 m.x = torch.rand(3, 3, device='mps')
8073 mps_out = torch.randint(low, high, shape, dtype=dtype, device='mps')
8088 x = torch.empty(10, 10, dtype=dtype, device='mps')
8096 mps_out = torch.zeros(shape, device='mps', dtype=dtype)
8128 mps_x = cpu_x.detach().clone().to('mps')
8131 mps_y = cpu_y.detach().clone().to('mps')
8142 mps_s = cpu_s.detach().clone().to('mps')
8147 mps_s = cpu_s.detach().clone().to('mps')
8165 x = cpu_x.detach().clone().to('mps')
8168 y = cpu_y.detach().clone().to('mps')
8183 x = torch.ones(4, dtype=torch.int32, device='mps')
8184 self.assertEqual(x + 1, torch.full((4,), 2, dtype=torch.int32, device='mps'))
8185 self.assertTrue(torch.equal(x + 1.5, torch.full((4,), 2.5, device='mps')))
8190 …rch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([True, False, True, False, True], …
8194 …ps_y = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([1, 0, 1, 0, 1], device="
8201 mps_x = cpu_x.detach().clone().to('mps')
8206 mps_x = cpu_x.to('mps')
8211 mps_x = cpu_x.detach().clone().to('mps')
8218 mps_x = cpu_x.detach().clone().to('mps')
8238 mps_x = cpu_x.detach().clone().to('mps')
8243 mps_x = cpu_x.to('mps')
8256 mps_x = cpu_x.detach().clone().to('mps')
8265 mps_x = cpu_x.detach().clone().to('mps')
8267 mps_y = cpu_y.detach().clone().to('mps')
8282 input_mps = input_cpu.detach().clone().to("mps")
8285 other_mps = other_cpu.detach().clone().to("mps")
8300 prob_tensor = cpu_prob_tensor.detach().clone().to('mps')
8318 x = torch.rand((3, 3), device="mps")
8325 x = torch.rand((3, 3), device="mps")
8338 x = cpu_x.detach().clone().to('mps')
8356 x = cpu_x.detach().clone().to('mps')
8359 other = cpu_other.detach().clone().to('mps')
8379 x = cpu_x.detach().clone().to('mps')
8382 other = cpu_other.detach().clone().to('mps')
8403 x = cpu_x.detach().clone().to('mps')
8406 other = cpu_other.detach().clone().to('mps')
8428 x = torch.randn((30, 15), device='mps', dtype=dtype)
8430 x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
8446 mps_x = cpu_x.detach().clone().to('mps')
8480 A_mps = A.clone().detach().to('mps')
8481 B_mps = B.clone().detach().to('mps')
8498 A = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
8499 B = torch.randn(size=[1, 4], device='mps', dtype=torch.float16)
8504 C = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
8516 # MPS
8517 input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
8518 target_mps = target_cpu.detach().clone().to('mps')
8548 def test_nll_loss_mismatched_batch(self, device='mps'):
8571 output_mps = test_nll_loss_out_of_bounds_ignore_index_helper(device='mps')
8573 for cpu, mps in zip(output_cpu, output_mps):
8574 self.assertEqual(cpu, mps)
8587 _test_nll_loss_invalid_target_dim(device='mps')
8605 _test_nll_loss_invalid_weights(device='mps')
8616 # MPS
8617 input_mps = input.detach().clone().to('mps').requires_grad_()
8618 target_mps = target.detach().clone().to('mps')
8619 weights_mps = weights.to("mps")
8636 # MPS
8637 input_mps = input.detach().clone().to('mps').requires_grad_()
8638 target_mps = target.detach().clone().to('mps')
8683 for dev in ['cpu', 'mps']:
8702 self.assertEqual(result_long['mps'].to('cpu'), result_long['cpu'])
8703 self.assertEqual(grad_long['mps'].to('cpu'), grad_long['cpu'])
8708 x = cpu_x.detach().clone().to('mps')
8806 device = 'mps'
8815 M_mps = M_cpu.to('mps')
8825 print(torch.ones(100, 100, device='mps').nonzero())
8826 print(torch.ones(100, 100, device='mps').nonzero().contiguous())
8937 … lambda: torch.conv2d(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 3, 3, device='mps')))
8940 … lambda: torch.conv2d(torch.rand(1, 3, 32, 32, device='mps'), torch.rand(1, 3, 3, 3)))
8943 def test_conv2d_valid_padding(self, device='mps'):
8954 x = torch.rand(1, 1, 10, 10, device="mps", requires_grad=True)
8955 m1 = nn.Conv2d(1, 1, 3, stride=2, padding=1).to("mps")
8956 m2 = nn.Conv2d(1, 1, 4, stride=2, padding=1).to("mps")
8966 x = torch.rand(1, 1, 10, 10, 20, device="mps", requires_grad=True)
8967 m1 = nn.Conv3d(1, 1, 3, stride=2, padding=1).to("mps")
8968 m2 = nn.Conv3d(1, 1, 4, stride=2, padding=1).to("mps")
8993 A_mps = A.detach().clone().to("mps")
8996 r2 = attention2(A_mps, device="mps")
9001 def test_group_norm_backward(self, device='mps'):
9028 # def test_conv2d_same_padding(self, device='mps'):
9052 input_mps = input_cpu.detach().clone().to("mps")
9061 input_mps = input_cpu.detach().clone().to("mps")
9069 k_mps = k_cpu.detach().clone().to("mps")
9072 x_mps = x_cpu.detach().clone().to("mps")
9084 inputMPS = inputCPU.detach().clone().to('mps')
9093 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
9181 def test_addmm(self, device="mps", dtype=torch.float32):
9217 def test_addr(self, device="mps", dtype=torch.float32):
9229 def test_matrix_rank(self, device="mps", dtype=torch.float32):
9274 … "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."):
9277 def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
9338 … "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
9345 … "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
9357 a_f32 = torch.rand((m, k), device="mps")
9358 b_f32 = torch.rand((k, n), device="mps")
9364 b_int32 = b_int32.to("mps")
9365 b_scales_and_zeros = b_scales_and_zeros.to("mps")
9394 a_f32 = torch.rand((m, k), device="mps")
9395 b_f32 = torch.rand((n, k), device="mps")
9437 q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad)
9438 k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
9439 v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
9484 causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps'))
9488 q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
9489 k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
9490 v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
9492 input_pos = torch.tensor([i], dtype=torch.int32, device='mps')
9517 x_mps = torch.zeros(10, dtype=torch.float32, device="mps")
9529 s = torch.tensor(input, dtype=torch.uint8, device="mps").unsqueeze(0)
9547 x_mps = x_cpu.to('mps')
9560 a_mps = torch.ones((2, 2),).to(torch.device("mps"))
9561 b_mps = torch.ones((2, 2),).to(torch.device("mps"))
9585 mps_x = cpu_x.detach().clone().to('mps')
9602 if base.device.type == 'mps':
9621 def test_diagonal_view(self, device="mps"):
9629 t = torch.ones((3, 3, 3), device="mps")
9636 def test_select_view(self, device="mps") -> None:
9644 def test_unbind_view(self, device="mps") -> None:
9654 def test_expand_view(self, device="mps") -> None:
9662 def test_expand_as_view(self, device="mps"):
9671 def test_narrow_view(self, device="mps"):
9679 def test_permute_view(self, device="mps") -> None:
9687 def test_transpose_view(self, device="mps"):
9696 def test_transpose_inplace_view(self, device="mps"):
9718 def test_t_view(self, device="mps"):
9728 t_mps = torch.ones((2, 6,), device='mps')[1].reshape(2, 3)
9734 def test_t_inplace_view(self, device="mps"):
9742 def test_T_view(self, device="mps"):
9751 def test_unfold_view(self, device="mps"):
9759 def test_squeeze_view(self, device="mps"):
9766 def test_squeeze_inplace_view(self, device="mps"):
9774 def test_unsqueeze_view(self, device="mps"):
9782 def test_unsqueeze_inplace_view(self, device="mps"):
9790 def test_as_strided_view(self, device="mps"):
9798 def test_as_strided_inplace_view(self, device="mps"):
9806 def test_view_view(self, device="mps"):
9814 def test_view_as_view(self, device="mps"):
9823 def test_contiguous_self(self, device="mps"):
9828 def test_contiguous_nonview(self, device="mps"):
9836 def test_reshape_view(self, device="mps"):
9844 def test_reshape_as_view(self, device="mps"):
9853 def test_reshape_nonview(self, device="mps"):
9861 def test_flatten_view(self, device="mps"):
9898 def test_flatten_nonview(self, device="mps"):
9918 def test_basic_indexing_slice_view(self, device="mps"):
9926 def test_basic_indexing_ellipses_view(self, device="mps"):
9934 def test_basic_indexing_newaxis_view(self, device="mps"):
9942 def test_chunk_view(self, device="mps"):
9952 def test_split_view(self, device="mps"):
9962 def test_movedim_view(self, device="mps"):
9985 def test_view_copy(self, device="mps"):
10002 def test_view_copy_out(self, device="mps"):
10021 def test_detached_view_copy(self, device="mps"):
10030 def test_empty_reshape(self, device="mps"):
10039 def test_expand(self, device="mps"):
10072 def test_view_empty(self, device="mps"):
10076 def test_reshape(self, device="mps"):
10108 def test_narrow(self, device="mps"):
10119 def test_narrow_tensor(self, device="mps"):
10129 def test_t(self, device="mps"):
10156 def test_split(self, device="mps"):
10190 def test_chunk(self, device="mps"):
10210 def test_unsqueeze(self, device="mps") -> None:
10225 def test_big_transpose(self, device="mps"):
10231 def test_T(self, device="mps"):
10239 def test_transposes(self, device="mps", dtype=torch.float32):
10254 def test_transposes_errors(self, device="mps", dtype=torch.float32):
10264 def test_python_types(self, device="mps"):
10278 def test_resize_as_preserves_strides(self, device="mps"):
10284 def test_memory_format_resize_as(self, device="mps"):
10285 def test_helper(shape, memory_format, device="mps"):
10291 test_helper((10, 3, 32, 32), torch.channels_last, device="mps")
10292 test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device="mps")
10294 def test_memory_format_resize_(self, device="mps"):
10295 def test_helper(shape, numel, memory_format, device="mps"):
10300 test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device="mps")
10301 test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device="mps")
10326 def test_atleast_gradient(self, device="mps"):
10331 def test_view(self, device="mps"):
10362 def test_contiguous(self, device="mps"):
10371 def test_resize_mps_dtypes(self, device="mps"):
10378 def test_resize_as_mps_dtypes(self, device="mps"):
10385 def test_resize_overflow(self, device="mps"):
10392 def test_view_all_dtypes_and_devices(self, device="mps"):
10403 conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
10406 y_gpu = y_cpu.to(device='mps')
10421 a_mps = a_cpu.detach().clone().to("mps")
10422 model_mps = model_cpu.to("mps")
10433 deconv_gpu = copy.deepcopy(deconv_cpu).to(device='mps')
10436 y_gpu = y_cpu.to(device='mps')
10450 device = 'mps'
10461 …_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
10462 … conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
10463 conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True)
10468 … x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True)
10490 a_mps = a_cpu.detach().clone().to("mps")
10491 model_mps = model_cpu.to("mps")
10501 x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()
10514 …_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
10515 … conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
10516 … conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()
10539 m_mps.weight.data = m_cpu.weight.data.detach().clone().to("mps").requires_grad_()
10540 m_mps.bias.data = m_cpu.bias.data.detach().clone().to("mps").requires_grad_()
10543 input_mps = input_cpu.detach().clone().to("mps")
10558 input_mps = input_cpu.detach().clone().to("mps")
10561 downsample_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps")
10562 …downsample_mps.weight.data = downsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
10563 … downsample_mps.bias.data = downsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()
10566 upsample_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps")
10567 … upsample_mps.weight.data = upsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
10568 upsample_mps.bias.data = upsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()
10585 y_gpu = y_cpu.to(device='mps')
10588 conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
10597 y_gpu = y_cpu.to(device='mps')
10600 conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
10663 …input_mps = input_cpu.detach().transpose(0, 1).to("mps").transpose(0, 1).requires_grad_(input_requ…
10664 grid_mps = get_grid('mps', grid_cpu.detach()).requires_grad_()
10667 out_mps.backward(gradients.to("mps"))
10678 … input_mps = base_input.to("mps").expand_as(input_mps).requires_grad_(input_requires_grad)
10740 input = torch.arange(1., 11, device="mps").view(1, 1, 2, 5)
10743 … [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]], device="mps").view(1, 2, 5, 2)
10749 … [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]], device="mps").view(1, 1, 2, 5)
10753 … [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]], device="mps").view(1, 1, 2, 5)
10758 … [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]], device="mps").view(1, 1, 2, 5)
10762 … [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]], device="mps").view(1, 1, 2, 5)
10767 … [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]], device="mps").view(1, 1, 2, 5)
10771 … [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]], device="mps").view(1, 1, 2, 5)
10779 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
10783 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
10788 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
10792 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
10797 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
10801 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
10809 … [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]], device="mps").view(1, 1, 2, 5)
10813 … [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]], device="mps").view(1, 1, 2, 5)
10818 … [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]], device="mps").view(1, 1, 2, 5)
10822 … [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]], device="mps").view(1, 1, 2, 5)
10827 … [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]], device="mps").view(1, 1, 2, 5)
10831 … [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]], device="mps").view(1, 1, 2, 5)
10849 device = "mps"
10859 device = "mps"
10897 device = "mps"
10930 device = "mps"
10953 device = "mps"
10959 # Test that MPS doesn't crash if nonzero called concurrently
10961 x = torch.rand(3, 3, device="mps")
10970 x = torch.rand(16, 16, device='mps', dtype=torch.float16)
10975 x_mps = x.to("mps")
10988 x_mps = x_cpu.detach().clone().to("mps")
10998 x_mps = x_cpu.detach().clone().to("mps")
11001 rows_mps = rows_cpu.detach().clone().to("mps")
11004 cols_mps = cols_cpu.detach().clone().to("mps")
11016 x_mps = x_cpu.detach().clone().to("mps")
11032 x_mps = x_cpu.detach().clone().to("mps")
11039 # MPS support binary op with uint8 natively starting from macOS 13.0
11046 x_mps = x_cpu.detach().clone().to("mps")
11077 x_mps = x_cpu.detach().clone().to("mps")
11082 out_tensor_mps = torch.tensor([88, 99], dtype=dtype, device="mps")
11123 target_mps = torch.zeros([5, 3], device="mps", dtype=dtype)
11126 indices_mps = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="mps")
11129 value_mps = torch.ones(indices_mps.shape[0], device="mps", dtype=dtype)
11139 def test_advancedindex_big(self, device="mps"):
11145 def test_set_item_to_scalar_tensor(self, device="mps"):
11155 def test_single_int(self, device="mps"):
11159 def test_multiple_int(self, device="mps"):
11164 def test_none(self, device="mps"):
11171 def test_step(self, device="mps"):
11179 def test_step_assignment(self, device="mps"):
11185 def test_bool_indices(self, device="mps"):
11201 def test_bool_indices_accumulate(self, device="mps"):
11208 def test_multiple_bool_indices(self, device="mps"):
11215 def test_byte_mask(self, device="mps"):
11226 def test_byte_mask_accumulate(self, device="mps"):
11235 def test_index_put_accumulate_expanded_values(self, device="mps"):
11274 def test_index_put_accumulate_non_contiguous(self, device="mps"):
11292 def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
11311 out_mps = func(t_dev, indices_dev, value0d.to("mps"))
11315 out_mps = func(t_dev, indices_dev, value1d.to("mps"))
11319 def test_index_put_accumulate_duplicate_indices(self, device="mps"):
11325 indices = delta.cumsum(0).long().to("mps")
11327 # abs for int64 is not supported on mps, fallback on 'cpu' to calculate it
11328 input = torch.randn(indices.cpu().abs().max().to("mps") + 1, device=device)
11340 def test_index_put_deterministic(self, device="mps"):
11368 def test_multiple_byte_mask(self, device="mps"):
11378 def test_byte_mask2d(self, device="mps"):
11385 def test_jit_indexing(self, device="mps"):
11403 def test_int_indices(self, device="mps"):
11416 [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.int32]]
11434 …[helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.float16, torch.long, torch.boo…
11436 def test_int_indices2d(self, device="mps"):
11443 def test_int_indices_broadcast(self, device="mps"):
11451 def test_empty_index(self, device="mps"):
11465 def test_empty_ndim_index(self, device="mps"):
11479 def test_empty_ndim_index_bool(self, device="mps"):
11483 def test_empty_slice(self, device="mps"):
11492 def test_index_getitem_copy_bools_slices(self, device="mps"):
11506 def test_index_setitem_bools_slices(self, device="mps"):
11533 def test_index_scalar_with_bool_mask(self, device="mps"):
11544 def test_setitem_expansion_error(self, device="mps"):
11555 def test_getitem_scalars(self, device="mps"):
11579 def test_setitem_scalars(self, device="mps"):
11603 def test_basic_advanced_combined(self, device="mps"):
11619 def test_int_assignment(self, device="mps"):
11628 def test_byte_tensor_assignment(self, device="mps"):
11642 def test_variable_slicing(self, device="mps"):
11648 def test_ellipsis_tensor(self, device="mps"):
11657 def test_invalid_index(self, device="mps"):
11661 def test_out_of_bound_index(self, device="mps"):
11670 def test_zero_dim_index(self, device="mps"):
11680 def test_cpu_indices(self, device="mps"):
11691 def test_nextafter(self, device="mps"):
11698 # greater is broken on MPS, see https://github.com/pytorch/pytorch/issues/125051
11780 … f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
11793 def test_lstm_forward(self, device="mps", dtype=torch.float32):
11798 def test_lstm_backward(self, device="mps", dtype=torch.float32):
11805 cell = cell_module(input_size, hidden_size, device='mps')
11816 input = torch.randn(3, input_size, device='mps')
11817 bad_hx = torch.randn(1, hidden_size, device='mps')
11818 good_hx = torch.randn(3, hidden_size, device='mps')
11835 input = torch.randn(3, 10, device='mps')
11836 hx = torch.randn(3, 20, device='mps')
11837 cx = torch.randn(3, 20, device='mps')
11838 lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
11845 input = torch.randn(3, 11, device='mps')
11846 hx = torch.randn(3, 20, device='mps')
11847 cx = torch.randn(3, 20, device='mps')
11848 lstm = nn.LSTMCell(10, 20, device='mps')
11852 input = torch.randn(3, 10, device='mps')
11853 hx = torch.randn(3, 21, device='mps')
11854 cx = torch.randn(3, 20, device='mps')
11855 lstm = nn.LSTMCell(10, 20, device='mps')
11861 # TODO: Remove once test_testing.py is running on MPS devices
11875 [torch.tensor([1], device='mps'), torch.tensor([2], device='mps')], {},
11876 "torch.lcm(torch.tensor([1], device='mps'), torch.tensor([2], device='mps'))")
11881 … with self.assertRaisesRegex(NotImplementedError, "not currently implemented for the MPS device"):
11927 a = torch.ones(1, device="mps")
11928 b = torch.zeros(1, device="mps")
11942 with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
11943 a = torch.ones(2, dtype=torch.float64, device="mps")
11945 a = torch.ones(2, device="mps")
11946 with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
11950 a = torch.ones(2, device="mps")
11956 # Ensures that cpu Tensor can be loaded on mps
11962 x2 = torch.load(f, map_location="mps")
11965 self.assertEqual(x2.device.type, "mps")
11967 # Ensures that mps Tensors can be loaded on mps
11969 x = torch.rand(2, device="mps")
11976 self.assertEqual(x2.device.type, "mps")
11978 # Ensures that mps Tensors can be loaded on cpu
11980 x = torch.rand(2, device="mps")
11989 # Ensures that `mps:0` Tensors can be loaded on mps
11991 x = torch.rand(2, device="mps:0")
11995 x2 = torch.load(f, map_location="mps:0")
11998 self.assertEqual(x2.device.type, "mps")
12129 …lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else…
12171 …lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else…
12220 mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs)
12236 self.assertEqual(device, "mps:0")
12262 # Allocate tensors on mps
12263 with torch.device("mps"):
12265 self.assertTrue(all(x.device.type == "mps" for x in inputs))
12278 # Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS
12301 # This is the MPS equivalent of `test_numpy_ref` from `test_ops.py`. It lives over here while
12302 # MPS still requires some fairly heavy special casing in the test framework.
12303 # When MPS becomes more consistent, this can probably be merged with that test using
12306 # MPS only supports float32
12309 …ike `test_numpy_ref`, this test compares in `float32` since at the time of this test's creation MPS
12336 # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
12337 # This requires mps to be properly registered in the device generic test framework which is not the
12341 instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
12342 instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
12343 instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")