# Owner(s): ["module: mps"]

import io
import platform
import sys
import math
import random
import unittest
import warnings
import subprocess
import tempfile
import os
import copy
import gc
import threading
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from collections import defaultdict
from torch import inf
from torch.nn import Buffer, Parameter
from torch.testing._internal import opinfo
from torch.testing._internal.common_utils import \
    (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, IS_CI,
     NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
import torch.backends.mps
from torch.distributions import Uniform, Exponential
from functools import partial, reduce

from torch.testing._internal.common_methods_invocations import (
    op_db,
    DecorateInfo,
    UnaryUfuncInfo,
    ReductionOpInfo,
    SpectralFuncInfo,
    BinaryUfuncInfo,
)
from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
import numpy as np
import torch
import torch.utils._pytree as pytree
from itertools import product
import operator

test_consistency_op_db = copy.deepcopy(op_db)
test_error_inputs_op_db = copy.deepcopy(op_db)

# Copied from `test_ops.py` for the purposes of duplicating `test_numpy_ref`
_ref_test_ops = tuple(
    filter(
        lambda op: not isinstance(
            op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
        )
        and op.ref is not None,
        op_db,
    )
)

def xfailIf(condition):
    def wrapper(func):
        if condition:
            return unittest.expectedFailure(func)
        else:
            return func
    return wrapper

def xfailIfMacOS14_4Plus(func):
    return unittest.expectedFailure(func) if product_version > 14.3 else func  # noqa: F821

def mps_ops_grad_modifier(ops):
    XFAILLIST_GRAD = {

        # precision issues
        'special.polygammaspecial_polygamma_n_0': [torch.float16],
        'polygammapolygamma_n_0': [torch.float16],
        'nn.functional.binary_cross_entropy': [torch.float16],

        # Unimplemented ops
        '__getitem__': [torch.float16],
        '_segment_reduce': [torch.float16, torch.float32],
        '_chunk_cat': [torch.float16, torch.float32],
        'unfold_copy': [torch.float16, torch.float32],  # unfold_backward is not implemented
        'unfold': [torch.float16, torch.float32],
        'sparse.mmreduce': [torch.float32],  # csr not supported
        'unique_consecutive': [torch.float16, torch.float32],
        'special_modified_bessel_i0': [torch.float16, torch.float32],
        'scalar_tensor': [torch.float16, torch.float32],
        'cdist': [torch.float32],
        'masked.scatter': [torch.float16, torch.float32],
        'index_fill': [torch.float16, torch.float32],  # missing `aten::_unique`.
        'linalg.lu_factor': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
        'aminmax': [torch.float32, torch.float16],

        # Correctness issues
        'atanh': [torch.float32],

        # Random output
        'exponential': [torch.float16, torch.float32],

        # CPU errors
        # derivative for aten::nextafter is not implemented on CPU
        'nextafter': None,
        # derivative for aten::floor_divide is not implemented on CPU
        'floor_divide': [torch.float16, torch.float32],
        # derivative for aten::narrow_copy is not implemented on CPU
        'narrow_copy': [torch.float16, torch.float32],
        # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
        'histogramdd': [torch.float16, torch.float32],
        # derivative for aten::histogram is not implemented
        'histogram': [torch.float16, torch.float32],
        # 'bool' object is not iterable
        'allclose': [torch.float16, torch.float32],
        'equal': [torch.float16, torch.float32],
        # 'float' object is not iterable
        'item': [torch.float16, torch.float32],
        # "mse_backward_cpu_out" not implemented for 'Half'
        'nn.functional.mse_loss': [torch.float16],
        # "smooth_l1_backward_cpu_out" not implemented for 'Half'
        'nn.functional.smooth_l1_loss': [torch.float16],
        # cpu error: grad requires non-empty inputs
        'randn': [torch.float16, torch.float32],
        'signal.windows.bartlett': [torch.float32],
        'signal.windows.blackman': [torch.float32],
        'signal.windows.cosine': [torch.float32],
        'signal.windows.exponential': [torch.float32],
        'signal.windows.gaussian': [torch.float32],
        'signal.windows.general_cosine': [torch.float32],
        'signal.windows.general_hamming': [torch.float32],
        'signal.windows.hamming': [torch.float32],
        'signal.windows.hann': [torch.float32],
        'signal.windows.kaiser': [torch.float32],
        'signal.windows.nuttall': [torch.float32],
        'eye': [torch.float16, torch.float32],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],

        # atomic operation in backward pass
        '_unsafe_masked_index': [torch.float16],
        '_unsafe_masked_index_put_accumulate': [torch.float16],
    }

    MACOS_12_3_XFAILLIST_GRAD = {
        # Unsupported Border padding mode; the forward pass succeeds because it falls back to cpu
        'grid_sampler_2d': [torch.float32],
        # Unimplemented
        'logaddexp2': [torch.float32],

    }

    MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'masked.softmin': [torch.float32, torch.float16],
        'masked.softmax': [torch.float32, torch.float16],
        'masked.log_softmax': [torch.float32, torch.float16],

        # Unsupported Border padding mode; the forward pass succeeds because it falls back to cpu
        'grid_sampler_2d': [torch.float32],

        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # The forward pass passes since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    SKIPLIST_GRAD = {
        'nn.functional.pairwise_distance': [torch.float16],
        # failed assertion `destination datatype must be fp32'
        'nn.functional.conv1d': [torch.float16],
        'nn.functional.conv2d': [torch.float16],
        'nn.functional.conv3d': [torch.float16],
        'nn.functional.conv_transpose1d': [torch.float16],
        'nn.functional.conv_transpose2d': [torch.float16],
        'nn.functional.conv_transpose3d': [torch.float16],
    }

    MACOS_13_3_XFAILLIST_GRAD = {
        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # The forward pass passes since `msort` doesn't return the indices, just the values, which match the CPU.
        # On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on MPS backend
        # TODO: remove these once downstream function 'aten::_linalg_svd.U' has been implemented
        'linalg.matrix_rank': None,

        # Exception: Caused by sample input at index 3 on MPS
        'nn.functional.conv3d': [torch.float32],
    }

    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        key = op.name + op.variant_test_name
        if key in XFAILLIST_GRAD:
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=XFAILLIST_GRAD[key]))

        if key in SKIPLIST_GRAD:
            addDecorator(op, DecorateInfo(
                         unittest.skip,
                         dtypes=SKIPLIST_GRAD[key]))

        if key in ON_MPS_XFAILLIST:
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=ON_MPS_XFAILLIST[key]))

        if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_12_3_XFAILLIST_GRAD[key]))

        if key in MACOS_BEFORE_13_3_XFAILLIST_GRAD and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_13_3_XFAILLIST_GRAD[key]))

        if key in MACOS_13_3_XFAILLIST_GRAD and (product_version >= 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_13_3_XFAILLIST_GRAD[key]))
        yield op

def mps_ops_modifier(ops):
    # Supported complex OPS
    SUPPORTED_COMPLEX_OPS = {
        '__radd__',
        '__rmul__',
        '__getitem__',
        'abs',
        'add',
        'alias_copy',
        'argwhere',
        'atleast_1d',
        'atleast_2d',
        'atleast_3d',
        'as_strided',
        'as_strided_copy',
        'as_strided_scatter',
        'broadcast_tensors',
        'broadcast_to',
        'chalf',
        'cfloat',
        'chunk',
        'clone',
        'conj',
        'conj_physical',
        'contiguous',
        'diag',
        'diag_embed',
        'diagflat',
        'diagonal',
        'diagonal_copy',
        'diagonal_scatter',
        'dsplit',
        'empty',
        'empty_permuted',
        'empty_strided',
        'eye',
        'exp',
        'expand',
        'expand_as',
        'expand_copy',
        'flatten',
        'fill',
        'full',
        'H',
        'hsplit',
        'imag',
        'index_select',
        'isfinite',
        'isinf',
        'isreal',
        'item',
        'kron',
        'linalg.diagonal',
        'linalg.svd',
        'linspace',
        'logspace',
        'linspacetensor_overload',
        'logspacetensor_overload',
        'mH',
        'mT',
        'masked_scatter',
        'masked_select',
        'meshgridlist_of_tensors',
        'meshgridvariadic_tensors',
        'movedim',
        'mul',
        'narrow',
        'narrow_copy',
        'nn.functional.conv1d',
        'nn.functional.conv2d',
        'nn.functional.conv_transpose1d',
        'nn.functional.conv_transpose2d',
        'nn.functional.feature_alpha_dropoutwithout_train',
        'nn.functional.padcircular',
        'nn.functional.tanhshrink',
        'nn.functional.unfold',
        'nonzero',
        'ones',
        'outer',
        'permute',
        'positive',
        'randn',
        'ravel',
        'real',
        'repeat_interleave',
        'reshape_as',
        'reshape',
        'resolve_conj',
        'resolve_neg',
        'scalar_tensor',
        'select',
        'sgn',
        'slice',
        'split',
        'split_with_sizes',
        'split_with_sizes_copy',
        'splitlist_args',
        'squeeze',
        'squeezemultiple',
        'sub',
        'svd',
        't',
        't_copy',
        'tanh',
        'tensor_split',
        'transpose',
        'T',
        'unbind',
        'unflatten',
        'unfold',
        'unfold_copy',
        'unsafe_chunk',
        'unsafe_split',
        'unsqueeze',
        'unsqueeze_copy',
        'view_as',
        'view_as_real',
        'view',
        'view_copy',
        'vsplit',
        'zero_',
        'zeros',
    }

    AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = {
        '__rdiv__',
        '__rmatmul__',
        '_chunk_cat',
        '_unsafe_masked_index',
        'acos',
        'acosh',
        'all',
        'allclose',
        'any',
        'addcdiv',
        'addcmul',
        'addmmdecomposed',
        'addmv',
        'asin',
        'atan',
        'atanh',
        'bfloat16',
        'bmm',
        'bool',
        'cartesian_prod',
        'cat',
        'char',
        'column_stack',
        'combinations',
        'corrcoef',
        'constant_pad_nd',
        'cos',
        'cosh',
        'count_nonzero',
        'diff',
        'div',
        'divno_rounding_mode',
        'dot',
        'dstack',
        'einsum',
        'eq',
        'equal',
        'exp2',
        'expm1',
        'fft.fft',
        'fft.fft2',
        'fft.fftn',
        'fft.fftshift',
        'fft.ifft',
        'fft.ifft2',
        'fft.ifftn',
        'fft.ifftshift',
        'fft.irfftn',
        'fft.irfft2',
        'fft.irfft',
        'fft.hfftn',
        'fft.hfft2',
        'fft.hfft',
        'flip',
        'fliplr',
        'flipud',
        'float',
        'gradient',
        'half',
        'hstack',
        'inner',
        'int',
        'isclose',
        'isnan',
        'ldexp',
        'linalg.multi_dot',
        'linalg.pinv',
        'log10',
        'log1p',
        'log2',
        'log',
        'logical_and',
        'logical_not',
        'logical_or',
        'logical_xor',
        'logsumexp',
        'long',
        'masked_fill',
        'masked.mean',
        'masked.prod',
        'masked.std',
        'masked.sum',
        'masked.var',
        'masked.logsumexp',
        'matmul',
        'mean',
        'mm',
        'mv',
        'ne',
        'neg',
        'nn.functional.padconstant',
        'nn.functional.padreflect',
        'nn.functional.padreplicate',
        'nn.functional.pixel_shuffle',
        'nn.functional.pixel_unshuffle',
        'nn.functional.rms_norm',
        'nn.functional.softsign',
        'pinverse',
        'prod',
        'reciprocal',
        'roll',
        'rot90',
        'rsqrt',
        'short',
        'sigmoid',
        'sin',
        'sinh',
        'sqrt',
        'square',
        'stack',
        'stft',
        'sum',
        'sum_to_size',
        'tan',
        'tensordot',
        'trace',
        'trapz',
        'trapezoid',
        'tril',
        'triu',
        'true_divide',
        'vstack',
        'where',
        'byte',
    }
    # These ops worked on MacOS 12, but are broken on MacOS 13, see https://github.com/pytorch/pytorch/issues/85758
    MACOS_12_3_XFAILLIST = {
        # Top 60
        # expected failures
        # The result of pow(9, 8) is showing 43046716, whereas it should've been 43046721.
        # fixed in macOS 13.3. Currently error is not raised.
        'pow': [torch.int16, torch.int64, torch.uint8, torch.int8],
        # expected failures
        '__rpow__': [torch.uint8, torch.int8],

        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'cdist': [torch.float32],
        'tan': [torch.uint8, torch.float32],

        # Data type support starts from macOS 13
        'nn.functional.avg_pool1d': [torch.int64],
        'nn.functional.avg_pool2d': [torch.int64],
        'nn.functional.local_response_norm': [torch.int64],
        '__radd__': [torch.uint8],
        '__rdiv__': [torch.uint8],
        '__rmul__': [torch.uint8],
        'abs': [torch.uint8],
        'acos': [torch.uint8],
        'acosh': [torch.uint8],
        'add': [torch.uint8],
        'asin': [torch.uint8],
        'asinh': [torch.uint8],
        'atan': [torch.uint8],
        'atanh': [torch.uint8],
        'ceil': [torch.uint8],
        'corrcoef': [torch.uint8],
        'cos': [torch.uint8],
        'cosh': [torch.uint8],
        'cov': [torch.uint8],
        'cumulative_trapezoid': [torch.uint8],
        'deg2rad': [torch.uint8],
        'diff': [torch.uint8],
        'eq': [torch.uint8],
        'equal': [torch.uint8],
        'erf': [torch.uint8],
        'exp2': [torch.uint8],
        'exp': [torch.uint8],
        'expm1': [torch.uint8],
        'floor': [torch.uint8],
        'fmax': [torch.uint8],
        'fmin': [torch.uint8],
        'fmod': [torch.uint8],
        'ge': [torch.uint8],
        'gt': [torch.uint8],
        'isclose': [torch.uint8],
        'isnan': [torch.uint8],
        'kron': [torch.uint8],
        'le': [torch.uint8],
        'log10': [torch.uint8],
        'log1p': [torch.uint8],
        'log2': [torch.uint8],
        'log': [torch.uint8],
        'logical_and': [torch.uint8],
        'logical_or': [torch.uint8],
        'logical_xor': [torch.uint8],
        'logit': [torch.uint8],
        'lt': [torch.uint8],
        'masked.mean': [torch.uint8],
        'masked.std': [torch.uint8],
        'masked.var': [torch.uint8],
        'maximum': [torch.uint8],
        'minimum': [torch.uint8],
        'mul': [torch.uint8],
        'ne': [torch.uint8],
        'neg': [torch.uint8],
        'nn.functional.cosine_embedding_loss': [torch.uint8],
        'nn.functional.margin_ranking_loss': [torch.uint8],
        'nn.functional.poisson_nll_loss': [torch.uint8],
        'nn.functional.softsign': [torch.uint8],
        'nn.functional.tanhshrink': [torch.uint8],
        'nn.functional.triplet_margin_loss': [torch.uint8],
        'nn.functional.triplet_margin_with_distance_loss': [torch.uint8],
        'nn.functional.pairwise_distance': [torch.uint8],
        'outer': [torch.uint8],
        'rad2deg': [torch.uint8],
        'reciprocal': [torch.uint8],
        'remainder': [torch.uint8],
        'round': [torch.uint8],
        'rsqrt': [torch.uint8],
        'sigmoid': [torch.uint8],
        'sign': [torch.uint8],
        'signbit': [torch.uint8],
        'sin': [torch.uint8],
        'sinh': [torch.uint8],
        'special.ndtr': [torch.uint8],
        'sqrt': [torch.uint8],
        'sub': [torch.uint8],
        'trapezoid': [torch.uint8],
        'trapz': [torch.uint8],
        'true_divide': [torch.uint8],
        'trunc': [torch.uint8],
        'xlogy': [torch.uint8],
        'minbinary': [torch.uint8],
        'maxbinary': [torch.uint8],
        'divtrunc_rounding': [torch.uint8],
        'divfloor_rounding': [torch.uint8],
        'divno_rounding_mode': [torch.uint8],
        'floor_divide': [torch.uint8],
        'ldexp': [torch.uint8],
        # square internally calls into power, and will type cast to int64, which is supported starting from macOS 13
        'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # inconsistency errors between cpu and mps, max seen atol is 2
        'nn.functional.interpolatebilinear': [torch.uint8],
    }

    MACOS_BEFORE_13_3_XFAILLIST = {
        # Failures due to precision issues (due to fast-math). These have been fixed in MacOS 13.3+
        'tan': [torch.float32],
        'cdist': [torch.float32],

        # CPU Error: cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # The tests below pass on macOS 12 as it falls back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        #  - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu')
        #  - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0')
        # The elements at index 30 and 5133 are equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but the returned indices are subject to undefined behaviour.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
        # Unsupported dtypes
        'cumsum': [torch.int64],
        'cumprod': [torch.int64],
        'cumulative_trapezoid': [torch.int64],
        'masked.cumsum': [torch.int64],
        'masked.cumprod': [torch.int64],
        'linalg.vander': [torch.int64],
    }

    MACOS_AFTER_13_1_XFAILLIST = {
        # before macOS 13.2 it falls back to cpu and passes the forward pass
        'grid_sampler_2d': [torch.float32],  # Unsupported Border padding mode
        # inconsistency errors between cpu and mps, max seen atol is 2
        'nn.functional.interpolatebilinear': [torch.uint8],
    }

    MACOS_13_3_XFAILLIST = {
        # Failure due to precision issue for fp16
        # on both cpu and mps there are test cases that might produce inf result
        # 'nn.functional.pairwise_distance': [torch.float16],

        # The tests below pass on macOS 12 as it falls back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        #  - CPU output: tensor([2546, 6917, 3181, ..., 7128, 5133, 30], device='cpu')
        #  - MPS output: tensor([2546, 6917, 3181, ..., 7128, 30, 5133], device='mps:0')
        # The elements at index 30 and 5133 are equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but the returned indices are subject to undefined behaviour.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
    }

    MACOS_BEFORE_14_4_XFAILLIST = {
        # These ops work fine in 14.4 but fail in 14.2 or 13.x
        'fft.hfft2': [torch.complex64],
    }

    # These ops are not expected to work
    UNIMPLEMENTED_XFAILLIST = {
        # Failures due to lack of op implementation on MPS backend
        'login': None,
        'linalg.eig': None,
        'linalg.eigvals': None,
        'put': None,
        'nn.functional.conv_transpose3d': None,
        'rounddecimals_neg_3': None,
        'rounddecimals_3': None,
        'rounddecimals_0': None,
        '__rsub__': None,
        'angle': None,
        'cauchy_': None,
        'cauchy': None,
        'cholesky': None,
        'cholesky_inverse': None,
        'cholesky_solve': None,
        'cummax': None,
        'cummin': None,
        'erfc': None,
        'frexp': None,
        'gcd': None,
        'geqrf': None,
        'nn.functional.grid_sample': None,  # Unsupported Border padding mode
        'heaviside': None,
        'i0': None,
        'igamma': None,
        'igammac': None,
        'index_copy': None,
        'index_reduceprod': None,
        'index_reducemean': None,
        'index_reduceamax': None,
        'index_reduceamin': None,
        'isneginf': None,
        'isposinf': None,
        'kthvalue': None,
        'lcm': None,
        'linalg.cholesky': None,
        'linalg.cholesky_ex': None,
        'linalg.cond': None,
        'linalg.detsingular': None,
        'linalg.det': None,
        'linalg.eigh': None,
        'linalg.eigvalsh': None,
        'linalg.householder_product': None,
        'linalg.ldl_factor': None,
        'linalg.ldl_factor_ex': None,
        'linalg.ldl_solve': None,
        'linalg.lstsq': None,
        'linalg.lstsqgrad_oriented': None,
        'linalg.lu': None,
        'linalg.lu_factor_ex': None,
        'linalg.lu_solve': None,
        'linalg.matrix_norm': [torch.float32],
        'linalg.norm': [torch.float32],
        'linalg.normsubgradients_at_zero': [torch.float32],
        'linalg.qr': None,
        'linalg.slogdet': None,
        'linalg.solve': None,
        'linalg.solve_ex': None,
        'linalg.svdvals': None,
        'linalg.tensorsolve': None,
        'linalg.vecdot': None,
        'logcumsumexp': None,
        'logdet': None,
        'lu': None,
        'lu_solve': None,
        'lu_unpack': None,
        'masked.median': None,
        'matrix_exp': None,
        'mode': None,
        'nanmedian': None,
        'native_dropout_backward': None,
        'normnuc': None,
        'nn.functional.fractional_max_pool2d': None,
        'nn.functional.fractional_max_pool3d': None,
        'nn.functional.adaptive_avg_pool3d': None,
        'nn.functional.adaptive_max_pool3d': None,
        'nn.functional.interpolatearea': None,
        'nn.functional.interpolatebicubic': None,
        'nn.functional.interpolatetrilinear': None,
        'nn.functional.max_unpool1dgrad': None,
        'nn.functional.max_unpool2dgrad': None,
        'nn.functional.max_unpool3dgrad': None,
        'nn.functional.avg_pool3d': None,
        'nn.functional.ctc_loss': None,
        'nn.functional.embedding_bag': None,
        'nn.functional.hardshrink': None,
        'nn.functional.max_pool3d': None,
        'nn.functional.max_unpool1d': None,
        'nn.functional.max_unpool2d': None,
        'nn.functional.max_unpool3d': None,
        'nn.functional.multi_margin_loss': None,
        'nn.functional.multilabel_margin_loss': None,
        'nn.functional.pdist': None,
        'nn.functional.rrelu': None,
        'nn.functional.norm': None,
        'ormqr': None,
        'pca_lowrank': None,
        'qr': None,
        'rsub': None,
        'scatter_reduceamax': None,
        'scatter_reduceamin': None,
        'scatter_reducemin': None,
        'scatter_reducemean': None,
        'scatter_reduceprod': None,
        'scatter_reducesum': None,
        'segment_reduce': None,
        '_segment.reduce': None,
        'segment.reduce': None,
        'segment_reduce_offsets': None,
        '_segment_reduce_offsets': None,
        '_segment_reduce_lengths': None,
        '_segment_reducelengths': None,
        '_segment_reduceoffsets': None,
        'sinc': None,
        'sparse.mm': None,
        'sparse.mmreduce': None,
        'special.airy_ai': None,
        'special.bessel_j0': None,
        'special.bessel_j1': None,
        'special.bessel_y0': None,
        'special.bessel_y1': None,
        'special.chebyshev_polynomial_t': None,
        'special.chebyshev_polynomial_u': None,
        'special.entr': None,
        'special.erfcx': None,
        'special.hermite_polynomial_h': None,
        'special.hermite_polynomial_he': None,
        'special.i0e': None,
        'special.i1': None,
        'special.i1e': None,
        'special.laguerre_polynomial_l': None,
        'special.log_ndtr': None,
        'special.modified_bessel_i0': None,
        'special.modified_bessel_i1': None,
        'special.modified_bessel_k0': None,
        'special.modified_bessel_k1': None,
        'special.ndtri': None,
        'special.scaled_modified_bessel_k0': None,
        'special.scaled_modified_bessel_k1': None,
        'special.spherical_bessel_j0': None,
        'special.xlog1py': None,
        'special.zeta': None,
        'svd_lowrank': None,
        'symeig': None,
        'take': None,
        'to': None,
        'to_sparse': None,
        'unique': None,
        'vdot': None,
        'segment_reduce_': None,
        '_upsample_bilinear2d_aa': None,
        'geometric': None,
        'geometric_': None,
        'log_normal_': None,
        'log_normal': None,
        'cdouble': None,
        'double': None,
        'nn.functional.softminwith_dtype': None,
        'log_softmaxwith_dtype': None,
        'softmaxwith_dtype': None,
        'float_power': None,
        'full_like': None,
        'linalg.matrix_rankhermitian': None,
        'linalg.pinvhermitian': None,
        'nonzero_static': None,

        # MPS: input sizes must be divisible by output sizes
        'nn.functional.adaptive_avg_pool1d': None,
        'nn.functional.adaptive_avg_pool2d': None,

        # Unsupported dtypes
        # bmm is not supported for integral types
        'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'ones_like': None,
        'zeros_like': None,

        # Convolution for integral types is not supported on MPS
        'nn.functional.conv1d': [torch.int64],
        'nn.functional.conv2d': [torch.int64],
        'nn.functional.conv3d': [torch.int64],
        'nn.functional.conv_transpose1d': [torch.int64],
        'nn.functional.conv_transpose2d': [torch.int64],

        # Unsupported dtypes
        'dot': [torch.int64],
        'histc': [torch.float16],
        'index_add': [torch.int64],
        'log1p': [torch.int64],
        'sigmoid': [torch.int64],
        'atan2': [torch.int64],

        # GEMM on MPS is not supported for integral types
        'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        '__rmatmul__': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'baddbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'bmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'einsum': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'inner': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'linalg.multi_dot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'matmul': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mat': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'tensordot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'unravel_index': [torch.int32, torch.int64],

        # new_zeros/new_ones: Cannot convert an MPS Tensor to float64 dtype as
        # the MPS framework doesn't support float64
        'new_zeros': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_ones': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # returned output on CPU is float64
        'bincount': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],

        # atomic operations not supported
        '_unsafe_masked_index_put_accumulate': [torch.bool, torch.int8, torch.uint8, torch.float16, torch.int16, torch.int64],
    }

    if product_version < 14.0:
        # FFT and BFloat16 support was added in MacOS 14
        UNIMPLEMENTED_XFAILLIST.update({
            'bfloat16': None,
            'fft.fft': None,
            'fft.fft2': None,
            'fft.fftn': None,
            'fft.hfft': None,
            'fft.hfft2': None,
            'fft.hfftn': None,
            'fft.ifft': None,
            'fft.ifft2': None,
            'fft.ifftn': None,
            'fft.ihfft': None,
            'fft.ihfft2': None,
            'fft.ihfftn': None,
            'fft.irfft': None,
            'fft.irfft2': None,
            'fft.irfftn': None,
            'fft.rfft': None,
            'fft.rfft2': None,
            'fft.rfftn': None,
            'stft': None,
            # Error in TestConsistencyCPU.test_output_match_isin_cpu fails for integers,
            # not reproducible in later OS. Added assert to op if used in < 14.0
            'isin': [torch.int64, torch.int32, torch.int16, torch.uint8, torch.int8],
            'nn.functional.max_pool2d': [torch.uint8],
        })

    if product_version < 15.0:
        UNIMPLEMENTED_XFAILLIST.update({
            'quantile': None,
            'nanquantile': None,
        })

    UNDEFINED_XFAILLIST = {
        # Top 60 operators
        # topk fails with duplicate indices
        'topk': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # Failures due to random output that they generate using
        # Philox engine causing mismatch with CPU results
        'multinomial': [torch.float16, torch.float32],  # random results
        'uniform': [torch.float16, torch.float32],
        'rand_like': [torch.float16, torch.float32],
        'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'randn_like': [torch.float16, torch.float32],
        'bernoulli': [torch.float16, torch.float32],
        'exponential': [torch.float16, torch.float32],
        'nn.functional.feature_alpha_dropoutwith_train': [torch.float16, torch.float32],
        'normal': [torch.float16, torch.float32],
        'normalin_place': [torch.float16, torch.float32],
        'normalnumber_mean': [torch.float16, torch.float32],
        'nn.functional.alpha_dropout': [torch.float16, torch.float32],
        'nn.functional.dropout': [torch.float16, torch.float32],
        'nn.functional.dropout2d': [torch.float16, torch.float32],
        'nn.functional.dropout3d': [torch.float16, torch.float32],
        # See https://github.com/pytorch/pytorch/issues/111479
        'nn.functional.multi_head_attention_forward': [torch.float32, torch.float16],

        # duplicate indices are used in the testcase - undefined behaviour
        'index_put': None,
        # zero to negative integer powers are undefined
        '__rpow__': [torch.int8, torch.int16, torch.int32, torch.int64],
        'resize_': [torch.float16, torch.float32],
        'resize_as_': [torch.float16, torch.float32],

        # CPU Errors:
        'addr': [torch.bool, torch.int16, torch.int32,
                 torch.int64, torch.uint8, torch.int8],  # "addmv_impl_cpu" not implemented for 'Half'
        'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
                                    torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values
        'as_strided_partial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
                                     torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values

        # random results
        # mps vs cpu:
        #   Mismatched elements: 40 / 96 (41.7%)
        #   Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        #   Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        # cuda(2.0.0.dev20230301+cu117) vs cpu:
        #   Mismatched elements: 56 / 96 (58.3%)
        #   Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        #   Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        'nn.functional.scaled_dot_product_attention': [torch.float32, torch.float16],

        # float output for float16 input on MPS
        'logit': [torch.float16],
    }

    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on MPS backend
        # TODO: remove these once downstream function 'aten::_linalg_svd.U' has been implemented
        'linalg.matrix_rank': None,
    }

    EMPTY_OPS_SKIPLIST = {
        # Fill tensors with uninitialized data, causing mismatch with CPU.
        # They occasionally match, thus skipping them.
        # See https://github.com/pytorch/pytorch/issues/100175
        'new_empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16,
                              torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # CPU: empty is returning all 0's and there is a mismatch with MPS
        # allocation (MacOS 13). According to
        # https://pytorch.org/docs/2.0/generated/torch.empty.html
        'empty': [torch.bool, torch.float16, torch.float32, torch.int16,
                  torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_permuted': [torch.bool, torch.float16, torch.float32, torch.int16,
                           torch.int32, torch.int64, torch.uint8, torch.int8],
    }

    SKIPLIST = {
        # Unsupported
        # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible
        'nn.functional.avg_pool2d': [torch.float16],

        # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16
        'nn.functional.conv3d': None,
    }

    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        key = op.name + op.variant_test_name
        if key in EMPTY_OPS_SKIPLIST:
            addDecorator(op, DecorateInfo(
                         unittest.skip("Skipping empty ops."),
                         dtypes=EMPTY_OPS_SKIPLIST[key]))
        if key in SKIPLIST:
            addDecorator(op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]))
        for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST, ON_MPS_XFAILLIST]:
            if key in xfaillist:
                addDecorator(op, DecorateInfo(
                             unittest.expectedFailure,
                             dtypes=xfaillist[key]))

        if key in MACOS_BEFORE_14_4_XFAILLIST and (product_version < 14.4):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_14_4_XFAILLIST[key]))

        if key in MACOS_BEFORE_13_3_XFAILLIST and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_13_3_XFAILLIST[key]))

        if key in MACOS_AFTER_13_1_XFAILLIST and torch.backends.mps.is_macos13_or_newer(2):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_AFTER_13_1_XFAILLIST[key]))

        if key in MACOS_13_3_XFAILLIST and (product_version >= 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_13_3_XFAILLIST[key]))

        if key in MACOS_12_3_XFAILLIST and (not torch.backends.mps.is_macos13_or_newer()):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_12_3_XFAILLIST[key]))

        # If the op is not supported for complex types, expect it to fail
        if key not in SUPPORTED_COMPLEX_OPS and (key not in AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS or product_version < 14.0):
            addDecorator(op, DecorateInfo(unittest.expectedFailure, dtypes=[torch.complex32, torch.complex64]))

        yield op
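
# Illustrative sketch (hypothetical helper name, not used by the suite): the
# modifier generators above yield the same OpInfo objects with extra
# DecorateInfo entries appended, so counting the added expected failures is
# one way to inspect their effect.
def _demo_count_xfail_decorators():
    decorated = mps_ops_modifier(copy.deepcopy(op_db))
    return sum(
        1
        for op in decorated
        for d in (op.decorators or [])
        if isinstance(d, DecorateInfo) and unittest.expectedFailure in d.decorators
    )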
def mps_ops_error_inputs_modifier(ops):
    # Error input samples do not take a dtype argument.
    XFAILLIST = {
        # Exceptions are not raised
        '__rmod__',
        '__rsub__',
        '__rpow__',
        'bernoulli',
        'clamp_max',
        'clamp_min',
        'masked_scatter',

        # unsupported float64 dtype
        'cat',
        'complex',
        'multinomial',
        'nn.functional.conv1d',
        'nn.functional.conv2d',
        'nn.functional.conv3d',
        'gather',
        'scatter',
        'scatter_add',

        # unsupported complex dtypes
        'masked_fill',

        # MPS does not support tensor dimensions > 16
        'amax',
        'amin',
        'aminmax',

        # memory overlapping checks
        'index_select',

        # unimplemented
        'logcumsumexp',
    }

    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        if op.error_inputs_func is None:
            continue
        key = op.name + op.variant_test_name
        if key in XFAILLIST:
            addDecorator(op, DecorateInfo(unittest.expectedFailure))
        yield op

# Same logic as test_cuda.py
if not torch.backends.mps.is_available():
    print('MPS not available, skipping tests', file=sys.stderr)
    TestCase = NoTest  # noqa: F811
    NNTestCase = NoTest  # noqa: F811

product_version = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
total_memory = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]))

# Determine whether to enable MPS memory leak check (uses same code as CUDA).
TEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'

def skipMPSMemoryLeakCheckIf(condition):
    def dec(fn):
        if getattr(fn, '_do_mps_memory_leak_check', True):
            fn._do_mps_memory_leak_check = not condition
        return fn
    return dec

class MpsMemoryLeakCheck:
    def __init__(self, testcase, name=None):
        self.name = testcase.id() if name is None else name
        self.testcase = testcase

    def __enter__(self):
        # Performs a gc if required (required if any memory is held)
        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
        if caching_allocator_mem_allocated > 0:
            gc.collect()
            torch.mps.empty_cache()

        # Acquires caching allocator and driver statistics before the test is run
        self.caching_allocator_before = torch.mps.current_allocated_memory()
        self.driver_before = torch.mps.driver_allocated_memory()

    def __exit__(self, exec_type, exec_value, traceback):
        # Don't check for leaks if an exception was thrown
        if exec_type is not None:
            return
        # Compares caching allocator before/after statistics
        # An increase in allocated memory is a discrepancy indicating a possible memory leak
        discrepancy_detected = False
        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
        if caching_allocator_mem_allocated > self.caching_allocator_before:
            discrepancy_detected = True

        # Short-circuits if no discrepancy detected
        if not discrepancy_detected:
            return
        # Validates the discrepancy persists after garbage collection and
        # is confirmed by the driver API
        gc.collect()
        torch.mps.empty_cache()

        discrepancy_detected = True
        # Query memory multiple times to ensure the leak was not transient
        for n in range(3):
            caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
            driver_mem_allocated = torch.mps.driver_allocated_memory()

            caching_allocator_discrepancy = False
            driver_discrepancy = False

            if caching_allocator_mem_allocated > self.caching_allocator_before:
                caching_allocator_discrepancy = True

            if driver_mem_allocated > self.driver_before:
                driver_discrepancy = True

            if not (caching_allocator_discrepancy or driver_discrepancy):
                # Leak was false positive, exit loop
                discrepancy_detected = False
                break

        if caching_allocator_discrepancy and not driver_discrepancy:
            # Just raises a warning if the leak is not validated by the driver API
            msg = ("MPS caching allocator reports a memory leak not "
                   f"verified by the driver API in {self.name}! "
                   f"Caching allocator allocated memory was {self.caching_allocator_before} "
                   f"and is now reported as {caching_allocator_mem_allocated}. "
                   f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")
            warnings.warn(msg)
        elif caching_allocator_discrepancy and driver_discrepancy:
            # A caching allocator discrepancy validated by the driver API is a failure
            msg = (f"MPS driver API confirmed a leak in {self.name}! "
                   f"Caching allocator allocated memory was {self.caching_allocator_before} "
                   f"and is now reported as {caching_allocator_mem_allocated}. "
                   f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")

            raise RuntimeError(msg)

class TestAutocastMPS(TestCase):

    def test_matmul_autocast(self):
        autocast_tensor_A = torch.rand((8, 8), device="mps")
        autocast_tensor_B = torch.rand((8, 8), device="mps")
        tensor_A = autocast_tensor_A.clone().detach()
        tensor_B = autocast_tensor_B.clone().detach()
        autocast_output_tensor = torch.empty(8, 8)
        output_tensor = autocast_output_tensor.clone().detach()

        with torch.autocast(device_type="mps"):
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B)
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor)

        output_tensor = torch.mm(tensor_A, tensor_B)
        output_tensor = torch.mm(tensor_A, output_tensor)

        self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16")
        self.assertEqual(autocast_output_tensor,
                         output_tensor.to(torch.float16),
                         f"Autocast & non-autocast tensors did not match, "
                         f"got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}")
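
# Illustrative sketch (hypothetical name, never called by the suite): because
# MpsMemoryLeakCheck is a plain context manager, it can also be used directly
# around a suspect block instead of through the TestCaseMPS wrapping below.
def _demo_direct_leak_check(testcase):
    with MpsMemoryLeakCheck(testcase, name="demo"):
        x = torch.ones(32, device="mps")
        del x  # freed before __exit__, so no discrepancy should be reported
        torch.mps.empty_cache()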
# Expand TestCase class with Memory Leak Detection on MPS device
class TestCaseMPS(TestCase):
    _do_mps_memory_leak_check = True

    def __init__(self, method_name='runTest'):
        super().__init__(method_name)
        test_method = getattr(self, method_name, None)
        if test_method is not None:
            # Wraps the tested method if we should do MPS memory check.
            if TEST_MPS_MEM_LEAK_CHECK:
                if self._do_mps_memory_leak_check:
                    self.wrap_with_mps_policy(method_name, self.assertLeaksNoMpsTensors)

    def assertLeaksNoMpsTensors(self, name=None):
        name = self.id() if name is None else name
        return MpsMemoryLeakCheck(self, name)

    def wrap_with_mps_policy(self, method_name, policy):
        test_method = getattr(self, method_name)
        setattr(self, method_name, super().wrap_method_with_policy(test_method, policy))

    # checks for leaks even if TEST_MPS_MEM_LEAK_CHECK is 0
    def wrap_with_mps_memory_check(self, method):
        return super().wrap_method_with_policy(method, self.assertLeaksNoMpsTensors)

class TestMemoryLeak(TestCaseMPS):
    def test_mps_memory_leak_detection(self):
        l = []

        @self.wrap_with_mps_memory_check
        def no_leak():
            pass

        # Trigger an intentional memory leak
        @self.wrap_with_mps_memory_check
        def leak_gpu0():
            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
            l.append(torch.randn(1024 * 1024 * 8, device=torch.device("mps")))

        no_leak()

        # Check that a RuntimeError was raised for the intentional leak, which
        # confirms that memory leak detection worked.
        with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
            leak_gpu0()

    def test_copy_cast_no_leak(self):

        def step(x):
            x = x.to(device='cpu', dtype=torch.float32)
            x = x.to(device='mps', dtype=torch.float16)

        a = torch.randn(128, 128, device='mps', dtype=torch.float16)
        # Warm up / prebuild MPS shaders (otherwise check fails on 13.2)
        step(a)
        torch.mps.empty_cache()
        driver_before = torch.mps.driver_allocated_memory()
        step(a)
        torch.mps.empty_cache()
        driver_after = torch.mps.driver_allocated_memory()
        self.assertEqual(driver_before, driver_after, f"Detected {driver_after-driver_before} bytes leak of GPU memory")
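
# Minimal usage sketch (illustrative only; `_demo_skip_leak_check_flag` is a
# hypothetical name): skipMPSMemoryLeakCheckIf simply flips the per-test
# `_do_mps_memory_leak_check` attribute that TestCaseMPS.__init__ consults.
def _demo_skip_leak_check_flag():
    @skipMPSMemoryLeakCheckIf(True)  # opt the wrapped test out of leak checking
    def probe():
        pass
    return probe._do_mps_memory_leak_check  # False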
class TestPixelShuffle(TestCaseMPS):
    def test_pixel_shuffle_unshuffle(self):
        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
                                                 upscale_factor=None, is_contiguous=True):

            def generate_input():
                # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
                channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
                height = random.randint(5, 10)
                width = random.randint(5, 10)

                if num_input_dims == 1:
                    input = torch.rand(channels, requires_grad=True, device='mps')
                    assert is_contiguous
                elif num_input_dims == 2:
                    input = torch.rand(width, height, requires_grad=True, device='mps').T
                    if is_contiguous:
                        input = input.contiguous()
                else:
                    batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                    input = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
                    input = input.transpose(-1, -2)
                    if is_contiguous:
                        input = input.contiguous()

                if not is_contiguous and len(input.reshape(-1)) > 0:
                    assert not input.is_contiguous()

                input = input.detach().clone()
                input.requires_grad = True
                return input

            # Function to imperatively ensure pixels are shuffled to the correct locations.
            # Used to validate the batch operations in pixel_shuffle.
            def _verify_pixel_shuffle(input, output, upscale_factor):
                for c in range(output.size(-3)):
                    for h in range(output.size(-2)):
                        for w in range(output.size(-1)):
                            height_idx = h // upscale_factor
                            weight_idx = w // upscale_factor
                            channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
                                          (c * upscale_factor ** 2)
                            self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])

            upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
            input = generate_input()

            ps = nn.PixelShuffle(upscale_factor)
            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)

            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
                output = ps(input)
                _verify_pixel_shuffle(input, output, upscale_factor)
                output.backward(output.data)
                self.assertEqual(input.data, input.grad.data)

                # Ensure unshuffle properly inverts shuffle.
                unshuffle_output = pus(output)
                self.assertEqual(input, unshuffle_output)
            else:
                self.assertRaises(RuntimeError, lambda: ps(input))

        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
                                                    downscale_factor=None):
            downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
            channels = random.randint(1, 4)
            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)

            if num_input_dims == 1:
                input = torch.rand(channels, requires_grad=True, device='mps')
            elif num_input_dims == 2:
                input = torch.rand(height, width, requires_grad=True, device='mps')
            else:
                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')

            pus = nn.PixelUnshuffle(downscale_factor)
            self.assertRaises(RuntimeError, lambda: pus(input))

        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
            # For 1D - 2D, this is an error case.
            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
            is_contiguous_check = [True, False] if num_input_dims > 1 else [True]
            for is_contiguous in is_contiguous_check:
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
                )
                _test_pixel_shuffle_unshuffle_helper(
                    num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
                )

            # Error cases for pixel_unshuffle.
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)

        def test_pixel_shuffle_unshuffle_1D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)

        def test_pixel_shuffle_unshuffle_2D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)

        def test_pixel_shuffle_unshuffle_3D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)

        def test_pixel_shuffle_unshuffle_4D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)

        def test_pixel_shuffle_unshuffle_5D():
            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)

        test_pixel_shuffle_unshuffle_1D()
        test_pixel_shuffle_unshuffle_2D()
        test_pixel_shuffle_unshuffle_3D()
        test_pixel_shuffle_unshuffle_4D()
        test_pixel_shuffle_unshuffle_5D()
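
# Small worked example (illustrative only; `_demo_pixel_shuffle_roundtrip` is
# a hypothetical name): pixel_shuffle maps (N, C*r^2, H, W) to (N, C, H*r, W*r)
# and PixelUnshuffle inverts it exactly, which is the property the class above
# verifies element by element.
def _demo_pixel_shuffle_roundtrip():
    x = torch.arange(16.0).reshape(1, 4, 2, 2)  # (N, C*r^2, H, W) with r=2
    shuffled = nn.PixelShuffle(2)(x)            # -> (1, 1, 4, 4)
    restored = nn.PixelUnshuffle(2)(shuffled)   # -> (1, 4, 2, 2)
    assert torch.equal(restored, x)
    return shuffled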
class MPSReluTest(TestCaseMPS):
    def _npRelu(self, np_features):
        return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)

    def testNpRelu(self):
        torch.testing.assert_close(
            np.array([[0., 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]),
            self._npRelu(
                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                         0.9]])))

    def _testRelu(self, np_features, device):
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        py_relu = torch.nn.ReLU(inplace=False)(py_tensor)
        py_relu_cpu = py_relu.to("cpu")

        self.assertEqual(np_relu, py_relu_cpu)

    def _testReluInPlace(self, np_features, device):
        np_relu = self._npRelu(np_features)
        # Convert the numpy array to a PyTorch Tensor,
        # and move the Tensor to the CPU/GPU based on the "device" parameter
        py_tensor = torch.from_numpy(np_features).to(device)
        py_relu = torch.nn.ReLU(inplace=True)(py_tensor)
        py_relu_cpu = py_relu.to("cpu")

        self.assertEqual(np_relu, py_relu_cpu)
        # In-place ReLU modifies the initial input, which should match the output of ReLU
        self.assertEqual(np_relu, py_tensor.to("cpu"))

    def testNumbersCPU(self):
        for t in [np.int32]:
            # Force execution on CPU even if a GPU kernel is available for the type.
            self._testRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="cpu")
            self._testReluInPlace(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="cpu")

    def testNumbersGPU(self):
        for t in [np.float16, np.float32]:
            self._testRelu(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="mps")
            self._testReluInPlace(
                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                device="mps")
            self._testRelu(np.array([]).astype(t), device="mps")
            self._testReluInPlace(np.array([]).astype(t), device="mps")

class MatmulTest(TestCaseMPS):
    def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
        if expand_tensor_1_shape:
            tensor1_mps = torch.randn(shape_tensor_1, device="mps").expand(expand_tensor_1_shape)
        else:
            tensor1_mps = torch.randn(shape_tensor_1, device="mps")

        if expand_tensor_2_shape:
            tensor2_mps = torch.randn(shape_tensor_2, device="mps").expand(expand_tensor_2_shape)
        else:
            tensor2_mps = torch.randn(shape_tensor_2, device="mps")

        tensor1_cpu = tensor1_mps.to("cpu")
        tensor2_cpu = tensor2_mps.to("cpu")

        matmul_cpu = torch.matmul(tensor1_cpu, tensor2_cpu)
        matmul_mps = torch.matmul(tensor1_mps, tensor2_mps)

        self.assertEqual(matmul_cpu, matmul_mps.to("cpu"))

    def test_vector_x_vector(self):
        # uses `dot`
        self._helper(3, 3)

    def test_matrix_x_vector(self):
        # uses `addmv`
        self._helper((3, 4), 4)

    def test_batched_matrix_x_broadcasted_vector(self):
        self._helper((10, 3, 4), 4)

    def test_batched_matrix_x_batched_matrix(self):
        # uses `bmm.out`
        self._helper((10, 3, 4), (10, 4, 5))

    def test_batched_matrix_x_broadcasted_matrix(self):
        self._helper((10, 3, 4), (4, 5))
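
# Illustrative sketch (hypothetical name, not part of the suite): the
# expand_tensor_*_shape arguments of MatmulTest._helper above exercise
# broadcasted (stride-0) operands; the same CPU/MPS comparison pattern looks
# like this for an expanded left-hand side.
def _demo_expanded_matmul_matches_cpu():
    a = torch.randn(1, 4, device="mps").expand(3, 4)  # stride-0 rows
    b = torch.randn(4, 5, device="mps")
    mps_result = torch.matmul(a, b)
    cpu_result = torch.matmul(a.cpu(), b.cpu())
    torch.testing.assert_close(cpu_result, mps_result.cpu())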
class MPSLeakyReluTest(TestCaseMPS):
    def _npLeakyRelu(self, np_features, negative_slope=0.1):
        return np.maximum(np_features, negative_slope * np_features).astype(np_features.dtype)

    def testNpLeakyRelu(self):
        torch.testing.assert_close(
            np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
                      [0.1, -0.03, 0.5, -0.07, 0.9]]),
            self._npLeakyRelu(
                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
                                                         0.9]]),
                negative_slope=0.1))

    def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
        mps_x = cpu_x.detach().clone().to('mps')

        if not contiguous and not (0 in shape or len(shape) < 2):
            # Transposing will make the tensor non-contiguous
            cpu_x = cpu_x.transpose(0, 1)
            mps_x = mps_x.transpose(0, 1)
            assert not mps_x.is_contiguous()

        cpu_x.requires_grad_()
        mps_x.requires_grad_()

        relu_op = torch.nn.LeakyReLU(negative_slope)

        cpu_leaky_relu = relu_op(cpu_x)
        mps_leaky_relu = relu_op(mps_x)
        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))

        # test backward pass

        cpu_grad = torch.ones_like(cpu_leaky_relu)
        mps_grad = cpu_grad.to('mps')

        mps_leaky_relu.backward(gradient=mps_grad)
        cpu_leaky_relu.backward(gradient=cpu_grad)

        assert cpu_x.grad is not None  # Check that the grad is well-populated
        self.assertEqual(cpu_x.grad, mps_x.grad)

    def testNumbersCPU(self):
        for t in [torch.float, torch.half]:
            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
                for contiguous in [True, False]:
                    self._testLeakyRelu(shape,
                                        dtype=t,
                                        negative_slope=0.2,
                                        contiguous=contiguous)

class TestAvgPool(TestCaseMPS):
    def _sum_pool2d(self, x, kernel_size):
        windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
        return torch.sum(windows, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        # Because unfold does not support a 3D sliding window, split the tensor
        # into multiple tensors and sum each of them.
        h = kernel_size[0]
        splited_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # sum_pool2d assumes tensor in (1, 1, n, m) view, so unsqueeze two times
        splited_x = [self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:]) for t in splited_x]
        joined_x = torch.cat(splited_x)
        return joined_x.view(1, joined_x.numel())

    def _avg_pool2d(self, x, kernel_size):
        size = reduce(operator.mul, kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        size = reduce(operator.mul, kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

    def test_avg_pool2d_with_zero_divisor(self):
        self.assertRaisesRegex(RuntimeError, "divisor must be not zero",
                               lambda: F.avg_pool2d(torch.zeros(3, 3, 3), (2, 2), divisor_override=0))

    def test_doubletensor_avg_pool2d_with_divisor(self):
        n, m = 3, 3
        input = torch.rand(1, 1, n, m)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool2d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())
        y = torch.nn.functional.avg_pool2d(
            x.to('mps'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertFalse(torch.isnan(y).any())
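
# Worked example (illustrative only; `_demo_unfold_sum_pool2d` is a
# hypothetical name): TestAvgPool builds its reference pooling from
# nn.functional.unfold, since summing each unfolded window reproduces
# avg_pool2d scaled by the window size.
def _demo_unfold_sum_pool2d():
    x = torch.arange(16.0).reshape(1, 1, 4, 4)
    windows = F.unfold(x, kernel_size=(2, 2), stride=(2, 2))  # (N, k*k, L)
    sum_pool = windows.sum(dim=1)                             # (N, L)
    avg_pool = F.avg_pool2d(x, kernel_size=(2, 2)).reshape(1, -1)
    assert torch.allclose(sum_pool, avg_pool * 4)  # window size k*k == 4
    return sum_pool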
1614class TestMPS(TestCaseMPS): 1615 def test_exp(self, device="mps", dtype=torch.float): 1616 for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()): 1617 b = torch.arange(18, dtype=dtype, device=device) / 3 * math.pi 1618 a = torch.tensor(v, dtype=dtype, device="mps") * b 1619 self.compare_with_numpy(torch.exp, np.exp, a) 1620 1621 def test_conv_raises_error(self, device='mps', dtype=torch.float): 1622 conv = nn.Conv1d(1, 65537, 3, padding=1).to('mps') 1623 1624 x = torch.ones([1, 1, 3]) 1625 with self.assertRaises(NotImplementedError): 1626 y = conv(x.to("mps")) 1627 1628 def test_triu_inf(self, device="mps", dtype=torch.float): 1629 for diag in [-1, 0, 1]: 1630 mask = torch.full((3, 6, 6), float("-inf")) 1631 mask_mps = mask.clone().detach().to('mps') 1632 cpu_ref = torch.triu(mask, diagonal=diag) 1633 mps_out = torch.triu(mask_mps, diagonal=diag) 1634 self.assertEqual(cpu_ref, mps_out) 1635 1636 def test_exp1(self, device="mps", dtype=torch.float): 1637 input = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype) 1638 output = torch.exp(input) 1639 output_cpu = torch.exp(input.cpu()) 1640 # If the exponentWithTensor: MPS call is used on an M1 running MacOS 14.5, the test will fail with 1641 # Mismatched elements: 3 / 4 (75.0%) 1642 # Greatest absolute difference: 1.1920928955078125e-07 at index (3,) (up to 1e-08 allowed) 1643 # Greatest relative difference: 1.0786502002702036e-07 at index (3,) (up to 1e-08 allowed) 1644 self.assertEqual(output, output_cpu, atol=1e-8, rtol=1e-8) 1645 1646 def test_exp_strided_output(self): 1647 x = torch.rand((256, 10), device='mps') 1648 x_cpu = x.to("cpu") 1649 1650 x = x.permute(1, 0) 1651 x_cpu = x_cpu.permute(1, 0) 1652 1653 res = x.exp() 1654 res_cpu = x_cpu.exp() 1655 self.assertEqual(res, res_cpu) 1656 1657 def _testLeakyRelu(self, np_features, negative_slope, device): 1658 cpu_x = torch.from_numpy(np_features).requires_grad_() 1659 mps_x = torch.from_numpy(np_features).to('mps').requires_grad_() 1660 relu_op = torch.nn.LeakyReLU(negative_slope) 1661 1662 cpu_leaky_relu = relu_op(cpu_x) 1663 mps_leaky_relu = relu_op(mps_x) 1664 torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu')) 1665 1666 # test backward pass 1667 cpu_grad = torch.ones_like(cpu_leaky_relu) 1668 mps_grad = cpu_grad.to('mps') 1669 cpu_leaky_relu.backward(gradient=cpu_grad) 1670 mps_leaky_relu.backward(gradient=mps_grad) 1671 torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu')) 1672 1673 def testNumbersGPU(self): 1674 for t in [np.float32]: 1675 self._testLeakyRelu( 1676 np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t), 1677 negative_slope=0.1, 1678 device="mps") 1679 1680 def test_fill(self): 1681 1682 def helper(val, shape, dtype): 1683 tensor = torch.zeros(shape, device='mps', dtype=dtype) 1684 tensor_mps = tensor.fill_(val) 1685 1686 tensor_0 = torch.zeros(shape, device='cpu', dtype=dtype) 1687 tensor_cpu = tensor_0.fill_(val) 1688 1689 self.assertEqual(tensor_mps, tensor_cpu) 1690 1691 helper(0, [1024], torch.float32) 1692 helper(0.2, [2, 3], torch.float32) 1693 helper(0.2 + 0.5j, [2, 3], torch.complex64) 1694 1695 def test_fill_storage_offset(self): 1696 shape = [2, 10] 1697 val = 0.2 1698 tensor = torch.ones(shape, device="mps") 1699 tensor_mps = tensor[:][1].fill_(val) 1700 tensor_0 = torch.ones(shape, device="cpu") 1701 tensor_cpu = tensor_0[:][1].fill_(val) 1702 1703 self.assertEqual(tensor_mps, tensor_cpu) 1704 self.assertEqual(tensor, tensor_0) 1705 1706 shape = [1, 10] 1707 val = 0.0 1708 tensor = torch.ones(shape, device="mps") 1709 val_tensor_mps = torch.tensor(val, device="mps") 1710 tensor_mps = tensor[:, 9].fill_(val_tensor_mps) 1711 # Regression test for https://github.com/pytorch/pytorch/issues/114692 1712 tensor[:, 5].fill_(val_tensor_mps) 1713 tensor_0 = torch.ones(shape, device="cpu") 1714 val_tensor_cpu = torch.tensor(val, device="cpu") 1715 tensor_cpu = tensor_0[:, 9].fill_(val_tensor_cpu) 1716 tensor_0[:, 5].fill_(val_tensor_cpu) 1717 1718 self.assertEqual(tensor_mps.to(device="cpu"), tensor_cpu) 1719 self.assertEqual(tensor.to(device="cpu"), tensor_0) 1720 1721 def test_cdist_large(self, device="mps"): 1722 for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: 1723 x = torch.randn(100, 10, device=device) 1724 y = torch.randn(100, 10, device=device) 1725 actual = torch.cdist(x, y, p=2, compute_mode=cm) 1726 expected = self._brute_cdist(x, y, p=2) 1727 self.assertEqual(expected, actual) 1728
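# Note on compute_mode for the cdist tests: with p=2, cdist can either expand
# the pairwise differences directly or use the matrix-multiplication identity
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x.y, which is faster but can be less
# accurate. Both paths are checked against the broadcast-based brute-force
# reference in _brute_cdist below.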
1729 def test_cdist_large_batch(self, device="mps"): 1730 for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: 1731 x = torch.randn(4, 3, 100, 10, device=device) 1732 y = torch.randn(4, 3, 100, 10, device=device) 1733 actual = torch.cdist(x, y, p=2, compute_mode=cm) 1734 expected = self._brute_cdist(x, y, p=2) 1735 self.assertEqual(expected, actual) 1736 1737 def test_cdist_non_contiguous(self, device="mps"): 1738 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: 1739 x = torch.randn(5, 7, device=device).mT 1740 y = torch.randn(5, 3, device=device).mT 1741 actual = torch.cdist(x, y, p=2, compute_mode=cm) 1742 expected = self._brute_cdist(x, y, p=2) 1743 self.assertFalse(x.is_contiguous()) 1744 self.assertFalse(y.is_contiguous()) 1745 self.assertEqual(expected, actual) 1746 1747 x = torch.randn(7, 5, device=device) 1748 y = torch.randn(5, 3, device=device).t() 1749 actual = torch.cdist(x, y, p=2, compute_mode=cm) 1750 expected = self._brute_cdist(x, y, p=2) 1751 self.assertTrue(x.is_contiguous()) 1752 self.assertFalse(y.is_contiguous()) 1753 self.assertEqual(expected, actual) 1754 1755 x = torch.randn(5, 7, device=device).t() 1756 y = torch.randn(3, 5, device=device) 1757 actual = torch.cdist(x, y, p=2, compute_mode=cm) 1758 expected = self._brute_cdist(x, y, p=2) 1759 self.assertFalse(x.is_contiguous()) 1760 self.assertTrue(y.is_contiguous()) 1761 self.assertEqual(expected, actual) 1762 1763 def test_cdist_non_contiguous_batch(self, device="mps"): 1764 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: 1765 x = torch.randn(4, 3, 2, 5, 7, device=device).mT 1766 y = torch.randn(4, 3, 2, 5, 3, device=device).mT 1767 actual = torch.cdist(x, y, p=2, compute_mode=cm) 1768 expected = self._brute_cdist(x, y, p=2) 1769 self.assertFalse(x.is_contiguous()) 1770 self.assertFalse(y.is_contiguous()) 1771 self.assertEqual(expected, actual) 1772 1773 x = torch.randn(7, 2, 7, 5, device=device) 1774 y = torch.randn(7, 2, 5, 3, device=device).mT 1775 actual = torch.cdist(x, y, p=2, compute_mode=cm) 1776 expected = self._brute_cdist(x, y, p=2) 1777 self.assertTrue(x.is_contiguous()) 1778 self.assertFalse(y.is_contiguous()) 1779 self.assertEqual(expected, actual) 1780 1781 x = torch.randn(4, 5, 7, device=device).mT 1782 y = torch.randn(4, 3, 5, device=device) 1783 actual = torch.cdist(x, y, p=2, compute_mode=cm) 1784 expected = self._brute_cdist(x, y, p=2) 1785 self.assertFalse(x.is_contiguous()) 1786 self.assertTrue(y.is_contiguous()) 1787 self.assertEqual(expected, actual) 1788 1789 def test_cdist_euclidean_large(self, device="mps"): 1790 def _test_euclidean_large_cdist(sizex, sizey=None): 1791 if sizey is None: 1792 sizey = sizex 1793 x = torch.randn(sizex, device=device, dtype=torch.float) 1794 y = torch.randn(sizey, device=device, dtype=torch.float) 1795 eps = 1e-6 1796 # to avoid extremum 1797 x = x - (((x - y) < eps).float() * 2 * eps) 1798 x.requires_grad = True 1799 y.requires_grad = True 1800 dist = torch.cdist(x, y, p=2) 1801 # Do a backward pass to check that it is valid for large 1802 # matrices 1803 loss = dist.sum() 1804 loss.backward() 1805 1806 _test_euclidean_large_cdist((2000, 5)) 1807 1808 def test_cdist_same_inputs(self, device="mps"): 1809 # Test to detect issues in the cdist gradient calculation 1810 # when the distances are 0 1811 sizex = (1, 27, 32) 1812 for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: 1813 x = torch.randn(sizex, device=device, dtype=torch.float) 1814 dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float) 1815 y = x.clone() 1816 eps = 1e-6 1817 x.requires_grad = True 1818 d = torch.cdist(x, y) 1819 d.backward(dist_grad) 1820 # Check that the backward pass does not 
contain invalid 1821 # values such as nan or inf 1822 assert torch.isfinite(x.grad).all() 1823 1824 1825 def _brute_cdist(self, x, y, p=2): 1826 r1 = x.shape[-2] 1827 r2 = y.shape[-2] 1828 if r1 == 0 or r2 == 0: 1829 return torch.empty(r1, r2, device=x.device) 1830 return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1) 1831 1832 def test_cdist_norm(self, device="mps"): 1833 for r1 in [3, 4]: 1834 for m in [2, 3]: 1835 for r2 in [4, 6]: 1836 for p in [0, 1, 1.5, 2.5, float('inf')]: 1837 x = torch.randn(r1, m, device=device) 1838 y = torch.randn(r2, m, device=device) 1839 if p == 2: 1840 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: 1841 actual = torch.cdist(x, y, p=2, compute_mode=cm) 1842 expected = self._brute_cdist(x, y, p=2) 1843 self.assertEqual(expected, actual, rtol=0, atol=0.02) 1844 else: 1845 actual = torch.cdist(x, y, p=p) 1846 expected = self._brute_cdist(x, y, p=p) 1847 self.assertEqual(expected, actual) 1848 1849 def test_cdist_norm_batch(self, device="mps"): 1850 for r1 in [3, 4]: 1851 for m in [2, 3]: 1852 for r2 in [4, 6]: 1853 for p in [0, 3, 1.5, 2.5, float('inf')]: 1854 x = torch.randn(2, 3, 6, r1, m, device=device) 1855 y = torch.randn(2, 3, 6, r2, m, device=device) 1856 if p == 2: 1857 for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: 1858 actual = torch.cdist(x, y, p=2, compute_mode=cm) 1859 expected = self._brute_cdist(x, y, p=2) 1860 self.assertEqual(expected, actual, rtol=0, atol=0.02) 1861 else: 1862 actual = torch.cdist(x, y, p=p) 1863 expected = self._brute_cdist(x, y, p=p) 1864 self.assertEqual(expected, actual) 1865 1866 def test_mm(self): 1867 B = torch.ones(5, 6).to("mps") 1868 C = torch.ones(6, 5).to("mps") 1869 D = torch.mm(B, C).cpu() 1870 torch.testing.assert_close(D, torch.full((5, 5), 6.0)) 1871 1872 def test_linalg_cross(self): 1873 def helper(dtype): 1874 device = "mps" 1875 if dtype is torch.int32 or dtype is torch.int64: 1876 x = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device) 1877 y = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device) 1878 else: 1879 x = torch.rand(100, 3, 100, dtype=dtype, device=device) 1880 y = torch.rand(100, 3, 100, dtype=dtype, device=device) 1881 x_cpu = x.to("cpu") 1882 y_cpu = y.to("cpu") 1883 res1 = torch.linalg.cross(x, y, dim=1) 1884 res2 = torch.tensor((), dtype=dtype, device=device) 1885 res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1) 1886 res2_cpu = torch.tensor((), dtype=dtype, device="cpu") 1887 torch.linalg.cross(x, y, dim=1, out=res2) 1888 torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu) 1889 self.assertEqual(res1, res2) 1890 self.assertEqual(res1, res1_cpu) 1891 self.assertEqual(res2, res2_cpu) 1892 1893 # test for broadcastable inputs 1894 if dtype is torch.int32 or dtype is torch.int64: 1895 x = torch.randint(0, 99999, (1, 3, 2), dtype=dtype, device=device) 1896 y = torch.randint(0, 99999, (4, 3, 1), dtype=dtype, device=device) 1897 else: 1898 x = torch.rand(1, 3, 2, dtype=dtype, device=device) 1899 y = torch.rand(4, 3, 1, dtype=dtype, device=device) 1900 x_cpu = x.to("cpu") 1901 y_cpu = y.to("cpu") 1902 res1 = torch.linalg.cross(x, y, dim=1) 1903 res2 = torch.tensor((), dtype=dtype, device=device) 1904 res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1) 1905 res2_cpu = torch.tensor((), dtype=dtype, device="cpu") 1906 torch.linalg.cross(x, y, dim=1, out=res2) 1907 torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu) 1908 self.assertEqual(res1, res2) 1909 self.assertEqual(res1, res1_cpu) 1910 
self.assertEqual(res2, res2_cpu) 1911 [helper(dtype) for dtype in [torch.int32, torch.int64, torch.float32]] 1912 1913 def test_cross(self): 1914 a = torch.randn(4, 3, device="mps") 1915 b = torch.randn(4, 3, device="mps") 1916 a_cpu = a.to("cpu") 1917 b_cpu = b.to("cpu") 1918 res = torch.cross(a, b, dim=1) 1919 res_cpu = torch.cross(a_cpu, b_cpu, dim=1) 1920 self.assertEqual(res, res_cpu) 1921 1922 def test_addmm(self): 1923 A = torch.ones(5, 5).to("mps") 1924 B = torch.ones(5, 6).to("mps") 1925 C = torch.ones(6, 5).to("mps") 1926 D = torch.addmm(A, B, C).to("cpu") 1927 torch.testing.assert_close(D, torch.full((5, 5), 7.0)) 1928 1929 def test_bmm(self): 1930 batch1_cpu = torch.randn(10, 3, 4) 1931 batch2_cpu = torch.randn(10, 4, 5) 1932 1933 batch1_mps = batch1_cpu.detach().clone().to("mps") 1934 batch2_mps = batch2_cpu.detach().clone().to("mps") 1935 1936 output_cpu = torch.bmm(batch1_cpu, batch2_cpu) 1937 output_mps = torch.bmm(batch1_mps, batch2_mps) 1938 1939 self.assertEqual(output_cpu, output_mps) 1940 self.assertEqual(output_cpu.size(), output_mps.size()) 1941 1942 @xfailIf(product_version < 15.0) 1943 @parametrize("dtype", [torch.float16, torch.bfloat16]) 1944 def test_large_bmm(self, dtype): 1945 batch1 = torch.randn(11, 20064, 128, dtype=dtype, device='mps') 1946 batch2 = torch.randn(11, 128, 20064, dtype=dtype, device='mps') 1947 output_cpu = torch.bmm(batch1.cpu(), batch2.cpu()) 1948 output_mps = torch.bmm(batch1, batch2) 1949 1950 # Using the low precision comparison for FP16 1951 tol = 1e-2 if dtype == torch.float16 else None 1952 self.assertEqual(output_cpu, output_mps, atol=tol, rtol=tol) 1953 self.assertEqual(output_cpu.size(), output_mps.size()) 1954 1955 1956 def test_addr(self): 1957 A = torch.ones(5, 10).to("mps") 1958 B = torch.ones(5).to("mps") 1959 C = torch.ones(10).to("mps") 1960 D = torch.addr(A, B, C).to("cpu") 1961 torch.testing.assert_close(D, torch.full((5, 10), 2.0)) 1962 1963 def test_trace(self): 1964 M_cpu = torch.randn(3, 3) 1965 M_mps = M_cpu.detach().clone().to("mps") 1966 1967 output_cpu = torch.trace(M_cpu) 1968 output_mps = torch.trace(M_mps) 1969 1970 self.assertEqual(output_cpu, output_mps) 1971 self.assertEqual(output_cpu.size(), output_mps.size()) 1972 1973 def test_addbmm(self): 1974 M_cpu = torch.randn(3, 5) 1975 batch1_cpu = torch.randn(10, 3, 4) 1976 batch2_cpu = torch.randn(10, 4, 5) 1977 1978 M_mps = M_cpu.detach().clone().to("mps") 1979 batch1_mps = batch1_cpu.detach().clone().to("mps") 1980 batch2_mps = batch2_cpu.detach().clone().to("mps") 1981 1982 output_cpu = torch.addbmm(M_cpu, batch1_cpu, batch2_cpu) 1983 output_mps = torch.addbmm(M_mps, batch1_mps, batch2_mps) 1984 1985 self.assertEqual(output_cpu, output_mps) 1986 self.assertEqual(output_cpu.size(), output_mps.size()) 1987 1988 def test_baddbmm(self): 1989 def helper(input_shape, batch1_shape, batch2_shape): 1990 M_cpu = torch.randn(input_shape) 1991 batch1_cpu = torch.randn(batch1_shape) 1992 batch2_cpu = torch.randn(batch2_shape) 1993 alpha = 1.2 1994 beta = 0.8 1995 1996 M_mps = M_cpu.detach().clone().to("mps") 1997 batch1_mps = batch1_cpu.detach().clone().to("mps") 1998 batch2_mps = batch2_cpu.detach().clone().to("mps") 1999 2000 output_cpu = torch.baddbmm(M_cpu, batch1_cpu, batch2_cpu, beta=beta, alpha=alpha) 2001 output_mps = torch.baddbmm(M_mps, batch1_mps, batch2_mps, beta=beta, alpha=alpha) 2002 2003 self.assertEqual(output_cpu, output_mps) 2004 self.assertEqual(output_cpu.size(), output_mps.size()) 2005 2006 helper(input_shape=(3, 5), batch1_shape=(10, 3, 4), 
batch2_shape=(10, 4, 5)) 2007 helper(input_shape=(10, 3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5)) 2008 helper(input_shape=(1, 77, 77), batch1_shape=(8, 77, 64), batch2_shape=(8, 64, 77)) 2009 2010 def test_local_scalar_dense_mps(self): 2011 x_cpu = torch.randn(1) 2012 y_mps = x_cpu.to("mps") 2013 torch.testing.assert_close(x_cpu.item(), y_mps.item()) 2014 2015 def test_linear_1d_weight(self): 2016 device = 'cpu' 2017 projected = torch.rand([8]).to(device) 2018 x = torch.rand([1, 2, 2, 8]).to(device) 2019 x_mps = x.to('mps') 2020 projected_mps = projected.to('mps') 2021 linear = F.linear(x, projected) 2022 linear_mps = F.linear(x_mps, projected_mps) 2023 2024 self.assertEqual(linear, linear_mps) 2025 2026 projected = torch.rand([1, 8]).to(device) 2027 x = torch.rand([1, 2, 2, 8]).to(device) 2028 x_mps = x.to('mps') 2029 projected_mps = projected.to('mps') 2030 linear = F.linear(x, projected) 2031 linear_mps = F.linear(x_mps, projected_mps) 2032 2033 self.assertEqual(linear, linear_mps) 2034 2035 def test_linear_bias(self): 2036 def helper(bias_shape): 2037 device = "cpu" 2038 x = torch.randn(2, 2, 2, 64, device=device) 2039 linear = torch.nn.Linear(64, 4, device=device) 2040 linear.bias = torch.nn.Parameter(torch.randn(bias_shape, dtype=torch.float32, device=device)) 2041 y = linear(x) 2042 device = "mps" 2043 x_mps = x.to(device) 2044 linear.to(device) 2045 y_mps = linear(x_mps) 2046 self.assertEqual(y, y_mps) 2047 2048 helper(()) 2049 helper((2, 4)) 2050 2051 def test_linear_errors(self): 2052 # Mixed CPU<->MPS tensors 2053 size = (3, 3) 2054 2055 # Unsupported dtypes 2056 with self.assertRaisesRegex(RuntimeError, "does not support linear for non-float weights"): 2057 torch.nn.functional.linear(torch.rand(size, device='mps'), 2058 torch.randint(-10, 10, size, dtype=torch.int8, device='mps')) 2059 2060 # Weights on wrong device 2061 with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"): 2062 torch.nn.functional.linear(torch.rand(size, device='mps'), 2063 torch.rand(size, device='cpu')) 2064 2065 # Input on wrong device 2066 with self.assertRaisesRegex(RuntimeError, "argument input is on cpu but expected on mps"): 2067 torch.nn.functional.linear(torch.rand(size, device='cpu'), 2068 torch.rand(size, device='mps')) 2069 2070 def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False): 2071 cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias) 2072 mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias) 2073 2074 # Use the same weights and bias as the CPU linear layer 2075 mps_linear.weight.data = cpu_linear.weight.data.detach().clone().to("mps") 2076 2077 if bias: 2078 mps_linear.bias.data = cpu_linear.bias.data.detach().clone().to("mps") 2079 2080 linear_mps_input = torch.randn(shape).to('mps') 2081 linear_cpu_input = linear_mps_input.detach().clone().to('cpu') 2082 2083 if backward_pass: 2084 linear_mps_input = linear_mps_input.requires_grad_() 2085 linear_cpu_input = linear_cpu_input.requires_grad_() 2086 2087 linear_cpu_output = cpu_linear(linear_cpu_input) 2088 linear_mps_output = mps_linear(linear_mps_input) 2089 2090 self.assertEqual(linear_cpu_output, linear_mps_output.to('cpu')) 2091 self.assertEqual(linear_cpu_output.size(), linear_mps_output.size()) 2092 2093 if backward_pass: 2094 cpu_grad = torch.rand_like(linear_cpu_output, requires_grad=True) 2095 grad = 
cpu_grad.detach().to('mps').requires_grad_() 2096 2097 linear_cpu_output.backward(gradient=cpu_grad, create_graph=True) 2098 linear_mps_output.backward(gradient=grad, create_graph=True) 2099 2100 self.assertEqual(linear_cpu_input.grad.size(), linear_mps_input.grad.size()) 2101 self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad.to("cpu"), atol=8e-04, rtol=10.4e-05) 2102 2103 self.assertEqual(cpu_linear.weight.grad.size(), mps_linear.weight.grad.size()) 2104 self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad.to("cpu"), atol=8e-04, rtol=10.4e-05) 2105 if bias: 2106 self.assertEqual(cpu_linear.bias.grad.size(), mps_linear.bias.grad.size()) 2107 self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad.to("cpu"), atol=8e-04, rtol=10.4e-05) 2108 2109 # test gradgrad 2110 x_grad_out = torch.rand_like(linear_cpu_input) 2111 x_grad_out_mps = x_grad_out.to("mps") 2112 w_grad_out = torch.rand_like(cpu_linear.weight) 2113 w_grad_out_mps = w_grad_out.to("mps") 2114 2115 linear_cpu_input.grad.detach().zero_() 2116 linear_mps_input.grad.detach().zero_() 2117 cpu_linear.weight.grad.detach().zero_() 2118 mps_linear.weight.grad.detach().zero_() 2119 if bias: 2120 b_grad_out = torch.rand_like(cpu_linear.bias) 2121 b_grad_out_mps = b_grad_out.to("mps") 2122 cpu_linear.bias.grad.detach().zero_() 2123 mps_linear.bias.grad.detach().zero_() 2124 2125 linear_cpu_input.grad.backward(x_grad_out, retain_graph=True) 2126 linear_mps_input.grad.backward(x_grad_out_mps, retain_graph=True) 2127 cpu_linear.weight.grad.backward(w_grad_out, retain_graph=True) 2128 mps_linear.weight.grad.backward(w_grad_out_mps, retain_graph=True) 2129 if bias: 2130 cpu_linear.bias.grad.backward(b_grad_out, retain_graph=True) 2131 mps_linear.bias.grad.backward(b_grad_out_mps, retain_graph=True) 2132 2133 self.assertEqual(cpu_grad.grad, grad.grad) 2134 self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad) 2135 self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad) 2136 if bias: 2137 self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad) 2138 2139 def test_linear1D(self): 2140 self._linear_helper(in_features=2, out_features=3, shape=([2]), bias=True, backward_pass=False) 2141 2142 def test_linear1D_backward(self): 2143 self._linear_helper(in_features=2, out_features=3, shape=([2]), bias=True, backward_pass=True) 2144 2145 def test_linear2D(self): 2146 self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=True, backward_pass=False) 2147 2148 def test_linear2D_backward(self): 2149 self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=True, backward_pass=True) 2150 2151 def test_linear2D_no_bias(self): 2152 self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=False, backward_pass=False) 2153 2154 def test_linear2D_no_bias_backward(self): 2155 self._linear_helper(in_features=2, out_features=3, shape=((4, 2)), bias=False, backward_pass=True) 2156 2157 def test_linear3D(self): 2158 self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=False) 2159 2160 def test_linear3D_backward(self): 2161 self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=True, backward_pass=True) 2162 2163 def test_linear3D_no_bias(self): 2164 self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=False, backward_pass=False) 2165 2166 def test_linear3D_no_bias_backward(self): 2167 self._linear_helper(in_features=2, out_features=3, shape=((4, 5, 2)), bias=False, backward_pass=True) 2168 2169 def 
test_uniform(self): 2170 low = torch.zeros(5, 5, requires_grad=True) 2171 high = (torch.ones(5, 5) * 3).requires_grad_() 2172 low_1d = torch.zeros(1, requires_grad=True) 2173 high_1d = (torch.ones(1) * 3).requires_grad_() 2174 self.assertEqual(Uniform(low, high).sample().size(), (5, 5)) 2175 self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5)) 2176 self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,)) 2177 self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1)) 2178 self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,)) 2179 2180 # Check log_prob computation when value outside range 2181 uniform = Uniform(low_1d, high_1d, validate_args=False) 2182 above_high = torch.tensor([4.0]) 2183 below_low = torch.tensor([-1.0]) 2184 self.assertEqual(uniform.log_prob(above_high).item(), -inf) 2185 self.assertEqual(uniform.log_prob(below_low).item(), -inf) 2186 2187 # check cdf computation when value outside range 2188 self.assertEqual(uniform.cdf(below_low).item(), 0) 2189 self.assertEqual(uniform.cdf(above_high).item(), 1) 2190 2191 state = torch.get_rng_state() 2192 rand = low.new(low.size()).uniform_() 2193 torch.set_rng_state(state) 2194 u = Uniform(low, high).rsample() 2195 u.backward(torch.ones_like(u)) 2196 self.assertEqual(low.grad, 1 - rand) 2197 self.assertEqual(high.grad, rand) 2198 low.grad.zero_() 2199 high.grad.zero_() 2200 2201 def test_randperm(self, device="mps"): 2202 rng_device = None 2203 for n in (5, 100, 50000, 100000): 2204 for dtype in (torch.long, torch.half, torch.float): 2205 if n > 2049 and dtype == torch.half: # Large n for torch.half will raise an exception, do not test here. 2206 continue 2207 if n > 256 and dtype == torch.bfloat16: 2208 continue 2209 with torch.random.fork_rng(devices=rng_device): 2210 res1 = torch.randperm(n, dtype=dtype, device=device) 2211 res2 = torch.empty(0, dtype=dtype, device=device) 2212 torch.randperm(n, out=res2, dtype=dtype, device=device) 2213 self.assertEqual(res1.cpu().sort().values.long(), torch.arange(n, device=device)) 2214 2215 # Default type is long 2216 for n in (100, 10000): 2217 self.assertEqual(torch.randperm(n, device=device).dtype, torch.long) 2218 2219 # randperm of 0 elements is an empty tensor 2220 res1 = torch.randperm(0) 2221 res2 = torch.tensor(5, dtype=dtype, device=device) 2222 torch.randperm(0, out=res2) 2223 self.assertEqual(res1.numel(), 0) 2224 self.assertEqual(res2.numel(), 0) 2225 2226 # Test non-contiguous tensors 2227 for n in (4, 5, 6, 10, 20): 2228 non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t() 2229 self.assertFalse(non_contiguous_tensor.is_contiguous()) 2230 with torch.random.fork_rng(devices=rng_device): 2231 res = torch.randperm(n, dtype=torch.long, device=device) 2232 torch.randperm(n, out=non_contiguous_tensor) 2233 self.assertEqual(res.cpu().sort().values.long(), torch.arange(n, device=device)) 2234 2235 # Test forward maxpool2d 2236 def test_max_pool2d(self): 2237 def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False): 2238 2239 cpu_x = None 2240 if (test_ties): 2241 cpu_x = torch.ones(shape, device='cpu', dtype=torch.float, requires_grad=True) 2242 else: 2243 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 2244 x = cpu_x.detach().clone().to('mps').requires_grad_() 2245 2246 pool = torch.nn.MaxPool2d(kernel_size=ks, padding=padding, dilation=dilation, 2247 ceil_mode=ceil_mode, return_indices=return_indices) 2248 2249 if (return_indices is False): 
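# Without indices, only the pooled values and the input gradient are
# compared against CPU; the else branch below additionally checks the
# argmax indices returned by MaxPool2d(..., return_indices=True).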
2250 y = pool(x) 2251 ref_y = pool(cpu_x) 2252 2253 cpu_grad = torch.ones_like(ref_y) 2254 grad = cpu_grad.to('mps') 2255 2256 y.backward(gradient=grad) 2257 ref_y.backward(gradient=cpu_grad) 2258 2259 self.assertEqual(y, ref_y) 2260 self.assertEqual(x.grad, cpu_x.grad) 2261 else: 2262 y, idx = pool(x) 2263 ref_y, ref_idx = pool(cpu_x) 2264 2265 cpu_grad = torch.ones_like(ref_y) 2266 grad = cpu_grad.to('mps') 2267 2268 y.backward(gradient=grad) 2269 ref_y.backward(gradient=cpu_grad) 2270 2271 self.assertEqual(y, ref_y) 2272 self.assertEqual(idx, ref_idx) 2273 self.assertEqual(x.grad, cpu_x.grad) 2274 2275 # Test with no batch dimension 2276 helper((8, 4, 4), ks=2) 2277 helper((2, 8, 4, 4), ks=2) 2278 helper((1, 1000, 32, 32), ks=4) 2279 helper((1, 1000, 1, 4), ks=(1, 4)) # test for max_pool1d 2280 # Test padding 2281 helper((1, 1000, 32, 32), ks=4, padding=1) 2282 helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1)) # test for max_pool1d 2283 # Test dilation 2284 helper((1, 1000, 32, 32), ks=4, dilation=2) 2285 helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2)) # test for max_pool1d 2286 # Test ceil mode 2287 helper((1, 1000, 32, 32), ks=4, ceil_mode=True) 2288 helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True) # test for max_pool1d 2289 2290 # Test return indices 2291 for test_ties in [False, True]: 2292 # Test with no batch dimension 2293 helper((8, 4, 4), ks=2, return_indices=True, test_ties=test_ties) 2294 helper((2, 8, 4, 4), ks=2, return_indices=True, test_ties=test_ties) 2295 helper((1, 1000, 32, 32), ks=4, return_indices=True, test_ties=test_ties) 2296 helper((1, 1000, 1, 4), ks=(1, 4), return_indices=True, test_ties=test_ties) # test for max_pool1d 2297 # Test padding 2298 helper((1, 1000, 32, 32), ks=4, padding=1, return_indices=True, test_ties=test_ties) 2299 helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1), 2300 return_indices=True, test_ties=test_ties) # test for max_pool1d 2301 # Test dilation 2302 helper((1, 1000, 32, 32), ks=4, dilation=2, return_indices=True, test_ties=test_ties) 2303 helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2), 2304 return_indices=True, test_ties=test_ties) # test for max_pool1d 2305 # Test ceil mode 2306 helper((1, 1000, 32, 32), ks=4, ceil_mode=True, return_indices=True, test_ties=test_ties) 2307 helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True, 2308 return_indices=True, test_ties=test_ties) # test for max_pool1d 2309 2310 def test_adaptive_avg_pool2d_output_size_one(self): 2311 def helper(size, memory_format): 2312 x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True) 2313 if memory_format == 'non_contiguous': 2314 x = x[::2, ::2, ::2, ::2] 2315 else: 2316 x = x.to(memory_format=memory_format) 2317 2318 net = torch.nn.AdaptiveAvgPool2d((1, 1)) 2319 out = net(x) 2320 ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1)) 2321 2322 out.sum().backward() # make sure it doesn't crash 2323 2324 self.assertEqual(out, ref_out) 2325 if memory_format == torch.channels_last: 2326 self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) 2327 c = out.size(1) 2328 self.assertEqual(out.stride(), [c, 1, c, c]) 2329 else: 2330 self.assertTrue(out.is_contiguous()) 2331 c = out.size(1) 2332 self.assertEqual(out.stride(), [c, 1, 1, 1]) 2333 2334 helper((2, 3, 6, 6), torch.contiguous_format) 2335 2336 def test_masked_scatter(self): 2337 def helper(shape): 2338 x_mps = torch.randn(shape, device="mps") 2339 x_cpu = x_mps.detach().clone().cpu() 2340 2341 mask_mps = torch.rand(shape, device="mps") < 0.6 2342 
mask_cpu = mask_mps.detach().clone().cpu() 2343 2344 y_mps = torch.randn(shape, device="mps") 2345 y_cpu = y_mps.detach().clone().cpu() 2346 2347 y_mps.masked_scatter_(mask_mps, x_mps) 2348 y_cpu.masked_scatter_(mask_cpu, x_cpu) 2349 2350 self.assertEqual(y_mps, y_cpu) 2351 helper([2, 5]) 2352 helper([10, 10]) 2353 helper([5, 10, 3]) 2354 helper([10, 5, 10, 3]) 2355 helper([10, 5, 10, 3, 20]) 2356 2357 def test_masked_fill(self): 2358 device = "mps" 2359 dtype = torch.float32 2360 mask_dtype = torch.bool 2361 num_dest = 10 2362 2363 dst = torch.zeros(num_dest, dtype=dtype, device=device) 2364 mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device) 2365 val = random.random() 2366 dst2 = torch.zeros(num_dest, dtype=dtype) 2367 mask_cpu = mask.to("cpu") 2368 2369 dst.masked_fill_(mask, val) 2370 for i in range(num_dest): 2371 if mask_cpu[i]: 2372 dst2[i] = val 2373 self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0) 2374 2375 def test_masked_fill__non_contiguous(self): 2376 shape = (3, 5) 2377 2378 x_mps = torch.randn(shape, device="mps") 2379 x_cpu = x_mps.detach().clone().cpu() 2380 mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool) 2381 mask_cpu = mask_mps.detach().clone().cpu() 2382 2383 x_mps_strided = x_mps.T 2384 x_cpu_strided = x_cpu.T 2385 2386 x_mps_strided.masked_fill_(mask_mps.T, float("-inf")) 2387 x_cpu_strided.masked_fill_(mask_cpu.T, float("-inf")) 2388 2389 self.assertEqual(x_mps_strided, x_cpu_strided) 2390 self.assertFalse((x_mps_strided == float("-inf")).any()) 2391 2392 def test_nhwc_operation(self): 2393 def helper(shape, channels_last=False): 2394 import numpy as np 2395 np.random.seed(332) 2396 arr = (256 - 128) * np.random.random_sample(size=shape) + 128 2397 cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True) 2398 if (channels_last): 2399 cpu_x = cpu_x.to(memory_format=torch.channels_last) 2400 cpu_x.retain_grad() 2401 x = cpu_x.detach().clone().to('mps').requires_grad_() 2402 2403 # This passes 2404 self.assertEqual(x, cpu_x) 2405 2406 helper((2, 2, 2, 2), True) 2407 2408 # Test forward batch norm 2409 def test_batch_norm(self): 2410 def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last=False, 2411 track_running_stats=True, test_module=False): 2412 2413 import numpy as np 2414 np.random.seed(332) 2415 arr = (256 - 128) * np.random.random_sample(size=shape) + 128 2416 cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True) 2417 if (channels_last): 2418 cpu_x = cpu_x.to(memory_format=torch.channels_last) 2419 cpu_x.retain_grad() 2420 x = cpu_x.detach().clone().to('mps').requires_grad_() 2421 2422 mean_shape = [shape[1]] 2423 cpu_running_mean = None 2424 cpu_running_var = None 2425 running_mean = None 2426 running_var = None 2427 if (track_running_stats): 2428 mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140 2429 cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float) 2430 var_arr = 32 * np.random.random_sample(size=mean_shape) 2431 cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float) 2432 running_mean = cpu_running_mean.detach().clone().to('mps') 2433 running_var = cpu_running_var.detach().clone().to('mps') 2434 2435 weight = None 2436 cpu_weight = None 2437 bias = None 2438 cpu_bias = None 2439 if (wts): 2440 cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True) 2441 weight = cpu_weight.detach().clone().to('mps').requires_grad_() 2442 cpu_bias = torch.randn(mean_shape, 
device='cpu', dtype=torch.float, requires_grad=True) 2443 bias = cpu_bias.detach().clone().to('mps').requires_grad_() 2444 2445 y = None 2446 ref_y = None 2447 2448 if (not test_module): 2449 y = torch.nn.functional.batch_norm(x, running_mean, running_var, 2450 weight=weight, 2451 bias=bias, 2452 training=training, 2453 momentum=momentum, eps=eps) 2454 ref_y = torch.nn.functional.batch_norm(cpu_x, cpu_running_mean, cpu_running_var, 2455 weight=cpu_weight, 2456 bias=cpu_bias, 2457 training=training, 2458 momentum=momentum, eps=eps) 2459 2460 else: 2461 2462 batchnorm_op = None 2463 mps_batchnorm_op = None 2464 2465 if (len(shape) == 3): 2466 batchnorm_op = torch.nn.BatchNorm1d(shape[1], 2467 eps=eps, 2468 momentum=momentum, 2469 affine=wts, 2470 track_running_stats=track_running_stats, 2471 device='cpu') 2472 mps_batchnorm_op = torch.nn.BatchNorm1d(shape[1], 2473 eps=eps, 2474 momentum=momentum, 2475 affine=wts, 2476 track_running_stats=track_running_stats, 2477 device='mps') 2478 elif (len(shape) == 4): 2479 batchnorm_op = torch.nn.BatchNorm2d(shape[1], 2480 eps=eps, 2481 momentum=momentum, 2482 affine=wts, 2483 track_running_stats=track_running_stats, 2484 device='cpu') 2485 mps_batchnorm_op = torch.nn.BatchNorm2d(shape[1], 2486 eps=eps, 2487 momentum=momentum, 2488 affine=wts, 2489 track_running_stats=track_running_stats, 2490 device='mps') 2491 elif (len(shape) == 5): 2492 batchnorm_op = torch.nn.BatchNorm3d(shape[1], 2493 eps=eps, 2494 momentum=momentum, 2495 affine=wts, 2496 track_running_stats=track_running_stats, 2497 device='cpu') 2498 mps_batchnorm_op = torch.nn.BatchNorm3d(shape[1], 2499 eps=eps, 2500 momentum=momentum, 2501 affine=wts, 2502 track_running_stats=track_running_stats, 2503 device='mps') 2504 2505 if (track_running_stats): 2506 batchnorm_op.running_mean = cpu_running_mean 2507 batchnorm_op.running_var = cpu_running_var 2508 mps_batchnorm_op.running_mean = running_mean 2509 mps_batchnorm_op.running_var = running_var 2510 if (wts): 2511 batchnorm_op.weight = torch.nn.Parameter(cpu_weight) 2512 batchnorm_op.bias = torch.nn.Parameter(cpu_bias) 2513 mps_batchnorm_op.weight = torch.nn.Parameter(weight) 2514 mps_batchnorm_op.bias = torch.nn.Parameter(bias) 2515 2516 ref_y = batchnorm_op(cpu_x) 2517 y = mps_batchnorm_op(x) 2518 2519 self.assertEqual(y, ref_y) 2520 if (not test_module): 2521 self.assertEqual(running_mean, cpu_running_mean) 2522 self.assertEqual(running_var, cpu_running_var) 2523 else: 2524 self.assertEqual(mps_batchnorm_op.running_mean, batchnorm_op.running_mean) 2525 self.assertEqual(mps_batchnorm_op.running_var, batchnorm_op.running_var) 2526 2527 cpu_grad = torch.randn(ref_y.shape) 2528 grad = cpu_grad.to('mps') 2529 ref_y.backward(gradient=cpu_grad) 2530 y.backward(gradient=grad) 2531 2532 self.assertEqual(x.grad, cpu_x.grad) 2533 if (wts): 2534 if (not test_module): 2535 self.assertEqual(weight.grad, cpu_weight.grad) 2536 self.assertEqual(bias.grad, cpu_bias.grad) 2537 else: 2538 self.assertEqual(mps_batchnorm_op.weight.grad, batchnorm_op.weight.grad) 2539 self.assertEqual(mps_batchnorm_op.bias.grad, batchnorm_op.bias.grad) 2540 2541 for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]: 2542 for test_module in [False, True]: 2543 for track_running_stats in [True, False]: 2544 for channels_last in [False]: 2545 if (channels_last and len(shape) != 4): 2546 continue 2547 # Running stats must be tracked in eval mode 2548 if (track_running_stats): 2549 helper(shape, eps=0, momentum=1, channels_last=channels_last, 2550 
track_running_stats=track_running_stats, test_module=test_module) 2551 helper(shape, channels_last=channels_last, 2552 track_running_stats=track_running_stats, test_module=test_module) 2553 helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last, 2554 track_running_stats=track_running_stats, test_module=test_module) 2555 helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last, 2556 track_running_stats=track_running_stats, test_module=test_module) 2557 helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last, 2558 track_running_stats=track_running_stats, test_module=test_module) 2559 helper(shape, eps=3, momentum=0.67, wts=True, training=False, channels_last=channels_last, 2560 track_running_stats=track_running_stats, test_module=test_module) 2561 helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last, 2562 track_running_stats=track_running_stats, test_module=test_module) 2563 helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last, 2564 track_running_stats=track_running_stats, test_module=test_module) 2565 helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last, 2566 track_running_stats=track_running_stats, test_module=test_module) 2567 helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last, 2568 track_running_stats=track_running_stats, test_module=test_module) 2569 2570 def test_batch_norm_backward(self): 2571 inputs = torch.rand(1, 8, 4, 4, device="mps", requires_grad=True) 2572 x = torch.nn.BatchNorm2d(8).to("mps") 2573 y = torch.nn.BatchNorm2d(8).to("mps") 2574 y.weight.requires_grad = False 2575 y.bias.requires_grad = False 2576 outputs = y(x(inputs)) 2577 # This used to crash, see https://github.com/pytorch/pytorch/issues/98602 2578 outputs.sum().backward() 2579 2580 # Regression test for https://github.com/pytorch/pytorch/issues/133520 2581 def test_batch_norm_slices(self): 2582 bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu') 2583 bn_mps = nn.BatchNorm2d(100, affine=False, device='mps') 2584 2585 x_cpu = torch.randn(100, 100, 35, 45).to('cpu') 2586 x_mps = x_cpu.to('mps') 2587 2588 res_cpu = bn_cpu(x_cpu[5:]) 2589 res_mps = bn_mps(x_mps[5:]) 2590 2591 self.assertEqual(res_cpu, res_mps) 2592 2593 def test_layer_norm_backward(self): 2594 inputs = torch.rand(4, 4, device="mps", requires_grad=True) 2595 x = torch.nn.LayerNorm(4).to("mps") 2596 y = torch.nn.LayerNorm(4).to("mps") 2597 y.weight.requires_grad = False 2598 y.bias.requires_grad = False 2599 outputs = y(x(inputs)) 2600 # This used to crash, see https://github.com/pytorch/pytorch/issues/98602 2601 outputs.sum().backward() 2602 2603 def test_norm(self): 2604 a = torch.arange(9, dtype=torch.float, device="mps") - 4 2605 b = a.reshape((3, 3)) 2606 2607 a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4 2608 b_cpu = a_cpu.reshape((3, 3)) 2609 2610 res = torch.norm(a) 2611 res_cpu = torch.norm(a_cpu) 2612 self.assertEqual(res, res_cpu) 2613 2614 res = torch.norm(b) 2615 res_cpu = torch.norm(b_cpu) 2616 self.assertEqual(res, res_cpu) 2617 2618 res = torch.norm(a, float('inf')) 2619 res_cpu = torch.norm(a_cpu, float('inf')) 2620 self.assertEqual(res, res_cpu) 2621 2622 res = torch.norm(b, float('inf')) 2623 res_cpu = torch.norm(b_cpu, float('inf')) 2624 self.assertEqual(res, res_cpu) 2625 2626 c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps") 2627 c_cpu = 
torch.tensor([[1, 2, 3], [-1, 1, 4]] , dtype=torch.float, device="cpu") 2628 2629 res = torch.norm(c, dim=0) 2630 res_cpu = torch.norm(c_cpu, dim=0) 2631 self.assertEqual(res, res_cpu) 2632 2633 res = torch.norm(c, dim=1) 2634 res_cpu = torch.norm(c_cpu, dim=1) 2635 self.assertEqual(res, res_cpu) 2636 2637 res = torch.norm(c, p=1, dim=1) 2638 res_cpu = torch.norm(c_cpu, p=1, dim=1) 2639 self.assertEqual(res, res_cpu) 2640 2641 d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2) 2642 d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2) 2643 2644 res = torch.norm(d, dim=(1, 2)) 2645 res_cpu = torch.norm(d_cpu, dim=(1, 2)) 2646 self.assertEqual(res, res_cpu) 2647 2648 res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :]) 2649 res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :]) 2650 self.assertEqual(res, res_cpu) 2651 2652 def test_linalg_vector_norm(self): 2653 x_mps = torch.tensor([0, 0, 0, 2, 3], dtype=torch.float, device="mps") 2654 x_cpu = x_mps.detach().clone().cpu() 2655 2656 res_mps = torch.linalg.vector_norm(x_mps, ord=0) 2657 res_cpu = torch.linalg.vector_norm(x_cpu, ord=0) 2658 self.assertEqual(res_mps, res_cpu) 2659 2660 a_mps = torch.arange(27, dtype=torch.float, device="mps") - 4 2661 a_cpu = torch.arange(27, dtype=torch.float, device="cpu") - 4 2662 2663 B_mps = a_mps.reshape(3, 3, 3) 2664 B_cpu = a_cpu.reshape(3, 3, 3) 2665 2666 res_mps = torch.linalg.vector_norm(a_mps, ord=3.5) 2667 res_cpu = torch.linalg.vector_norm(a_cpu, ord=3.5) 2668 self.assertEqual(res_mps, res_cpu) 2669 2670 res_mps = torch.linalg.vector_norm(B_mps, ord=3.5) 2671 res_cpu = torch.linalg.vector_norm(B_cpu, ord=3.5) 2672 self.assertEqual(res_mps, res_cpu) 2673 2674 for dim in range(0, B_mps.dim()): 2675 res_mps = torch.linalg.vector_norm(B_mps, ord=3.5, dim=dim) 2676 res_cpu = torch.linalg.vector_norm(B_cpu, ord=3.5, dim=dim) 2677 self.assertEqual(res_mps, res_cpu) 2678 2679 2680 def test_layer_norm(self): 2681 # TODO: Test non-contiguous 2682 def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32): 2683 cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True) 2684 x = cpu_x.detach().clone().to('mps').requires_grad_() 2685 2686 cpu_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='cpu', dtype=dtype) 2687 mps_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='mps', dtype=dtype) 2688 cpu_wt = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True) 2689 wt = cpu_wt.detach().clone().to('mps').requires_grad_() 2690 cpu_bias = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True) 2691 bias = cpu_bias.detach().clone().to('mps').requires_grad_() 2692 2693 if (elementwise_affine): 2694 cpu_op.weight = torch.nn.Parameter(cpu_wt) 2695 mps_op.weight = torch.nn.Parameter(wt) 2696 cpu_op.bias = torch.nn.Parameter(cpu_bias) 2697 mps_op.bias = torch.nn.Parameter(bias) 2698 2699 cpu_result = cpu_op(cpu_x) 2700 result = mps_op(x) 2701 2702 cpu_grad = torch.randn(cpu_result.shape) 2703 grad = cpu_grad.to('mps') 2704 2705 cpu_result.backward(cpu_grad) 2706 result.backward(grad) 2707 2708 self.assertEqual(result, cpu_result) 2709 self.assertEqual(x.grad, cpu_x.grad) 2710 if (elementwise_affine): 2711 self.assertEqual(mps_op.weight.grad, cpu_op.weight.grad) 2712 self.assertEqual(mps_op.bias.grad, cpu_op.bias.grad) 2713 2714 for elementwise_affine in [True, False]: 2715 
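# normalized_shape must match the trailing dimensions of the input; the
# cases below normalize over the last two or three dims, with and without
# learnable affine parameters.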
helper((2, 2, 2, 2), (2, 2), elementwise_affine=elementwise_affine) 2716 helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine) 2717 helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine) 2718 2719 # Regression test for https://github.com/pytorch/pytorch/issues/96113 2720 torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(torch.randn(1, 2, 16).to("mps", dtype=torch.float16)) 2721 2722 @xfailIf(product_version < 14.0) 2723 def test_ifft(self): 2724 # See: https://github.com/pytorch/pytorch/issues/124096 2725 device = torch.device("mps") 2726 2727 N = 64 2728 signal = torch.rand(N, device=device) 2729 fft_result = torch.fft.rfft(signal) 2730 ifft_result = torch.fft.irfft(fft_result, n=signal.shape[0]) 2731 2732 # Expecting the inverted to yield the original signal 2733 self.assertEqual(ifft_result, signal) 2734 2735 # Regression test for https://github.com/pytorch/pytorch/issues/135223 2736 def test_fftfreq(self): 2737 freq_cpu = torch.fft.fftfreq(10**4, device='cpu') 2738 freq_mps = torch.fft.fftfreq(10**4, device='mps') 2739 self.assertEqual(freq_cpu, freq_mps) 2740 2741 def test_instance_norm(self): 2742 def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False): 2743 2744 import numpy as np 2745 np.random.seed(332) 2746 arr = (256 - 128) * np.random.random_sample(size=shape) + 128 2747 cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True) 2748 if (channels_last): 2749 cpu_x = cpu_x.to(memory_format=torch.channels_last) 2750 cpu_x.retain_grad() 2751 x = cpu_x.detach().clone().to('mps').requires_grad_() 2752 2753 mean_shape = [shape[1]] 2754 cpu_running_mean = None 2755 cpu_running_var = None 2756 running_mean = None 2757 running_var = None 2758 if (track_running_stats): 2759 mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140 2760 cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float) 2761 var_arr = 32 * np.random.random_sample(size=mean_shape) 2762 cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float) 2763 running_mean = cpu_running_mean.detach().clone().to('mps') 2764 running_var = cpu_running_var.detach().clone().to('mps') 2765 2766 weight = None 2767 cpu_weight = None 2768 bias = None 2769 cpu_bias = None 2770 if (wts): 2771 cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True) 2772 weight = cpu_weight.detach().clone().to('mps').requires_grad_() 2773 cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True) 2774 bias = cpu_bias.detach().clone().to('mps').requires_grad_() 2775 2776 y = None 2777 ref_y = None 2778 2779 if (not test_module): 2780 ref_y = torch.nn.functional.instance_norm(cpu_x, cpu_running_mean, cpu_running_var, 2781 weight=cpu_weight, 2782 bias=cpu_bias, 2783 momentum=momentum, eps=eps) 2784 y = torch.nn.functional.instance_norm(x, running_mean, running_var, 2785 weight=weight, 2786 bias=bias, 2787 momentum=momentum, eps=eps) 2788 2789 else: 2790 2791 instancenorm_op = None 2792 mps_instancenorm_op = None 2793 2794 if (len(shape) == 3): 2795 instancenorm_op = torch.nn.InstanceNorm1d(shape[1], 2796 eps=eps, 2797 momentum=momentum, 2798 affine=wts, 2799 track_running_stats=track_running_stats, 2800 device='cpu') 2801 mps_instancenorm_op = torch.nn.InstanceNorm1d(shape[1], 2802 eps=eps, 2803 momentum=momentum, 2804 affine=wts, 2805 track_running_stats=track_running_stats, 2806 device='mps') 2807 elif (len(shape) == 4): 2808 
instancenorm_op = torch.nn.InstanceNorm2d(shape[1], 2809 eps=eps, 2810 momentum=momentum, 2811 affine=wts, 2812 track_running_stats=track_running_stats, 2813 device='cpu') 2814 mps_instancenorm_op = torch.nn.InstanceNorm2d(shape[1], 2815 eps=eps, 2816 momentum=momentum, 2817 affine=wts, 2818 track_running_stats=track_running_stats, 2819 device='mps') 2820 elif (len(shape) == 5): 2821 instancenorm_op = torch.nn.InstanceNorm3d(shape[1], 2822 eps=eps, 2823 momentum=momentum, 2824 affine=wts, 2825 track_running_stats=track_running_stats, 2826 device='cpu') 2827 mps_instancenorm_op = torch.nn.InstanceNorm3d(shape[1], 2828 eps=eps, 2829 momentum=momentum, 2830 affine=wts, 2831 track_running_stats=track_running_stats, 2832 device='mps') 2833 2834 if (track_running_stats): 2835 instancenorm_op.running_mean = cpu_running_mean 2836 instancenorm_op.running_var = cpu_running_var 2837 mps_instancenorm_op.running_mean = running_mean 2838 mps_instancenorm_op.running_var = running_var 2839 if (wts): 2840 instancenorm_op.weight = torch.nn.Parameter(cpu_weight) 2841 instancenorm_op.bias = torch.nn.Parameter(cpu_bias) 2842 mps_instancenorm_op.weight = torch.nn.Parameter(weight) 2843 mps_instancenorm_op.bias = torch.nn.Parameter(bias) 2844 2845 ref_y = instancenorm_op(cpu_x) 2846 y = mps_instancenorm_op(x) 2847 2848 self.assertEqual(y, ref_y) 2849 if (not test_module): 2850 self.assertEqual(running_mean, cpu_running_mean) 2851 self.assertEqual(running_var, cpu_running_var) 2852 else: 2853 self.assertEqual(mps_instancenorm_op.running_mean, instancenorm_op.running_mean) 2854 self.assertEqual(mps_instancenorm_op.running_var, instancenorm_op.running_var) 2855 2856 cpu_grad = torch.randn(ref_y.shape) 2857 grad = cpu_grad.to('mps') 2858 ref_y.backward(gradient=cpu_grad) 2859 y.backward(gradient=grad) 2860 2861 self.assertEqual(x.grad, cpu_x.grad) 2862 if (wts): 2863 if (not test_module): 2864 self.assertEqual(weight.grad, cpu_weight.grad) 2865 self.assertEqual(bias.grad, cpu_bias.grad) 2866 else: 2867 self.assertEqual(mps_instancenorm_op.weight.grad, instancenorm_op.weight.grad) 2868 self.assertEqual(mps_instancenorm_op.bias.grad, instancenorm_op.bias.grad) 2869 2870 for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]: 2871 for test_module in [False, True]: 2872 for track_running_stats in [True, False]: 2873 for channels_last in [False]: 2874 if (channels_last and len(shape) != 4): 2875 continue 2876 # Running stats must be tracked in eval mode 2877 if (track_running_stats): 2878 helper(shape, eps=0, momentum=1, channels_last=channels_last, 2879 track_running_stats=track_running_stats, test_module=test_module) 2880 helper(shape, channels_last=channels_last, 2881 track_running_stats=track_running_stats, test_module=test_module) 2882 helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last, 2883 track_running_stats=track_running_stats, test_module=test_module) 2884 helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last, 2885 track_running_stats=track_running_stats, test_module=test_module) 2886 helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last, 2887 track_running_stats=track_running_stats, test_module=test_module) 2888 helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last, 2889 track_running_stats=track_running_stats, test_module=test_module) 2890 helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last, 2891 track_running_stats=track_running_stats, test_module=test_module) 2892 helper(shape, eps=0, momentum=1.0, 
wts=False, channels_last=channels_last, 2893 track_running_stats=track_running_stats, test_module=test_module) 2894 helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last, 2895 track_running_stats=track_running_stats, test_module=test_module) 2896 helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last, 2897 track_running_stats=track_running_stats, test_module=test_module) 2898 2899 def test_weight_norm(self): 2900 def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim): 2901 cpu_norm = torch.nn.utils.parametrizations.weight_norm(cpu_model, dim=dim) 2902 norm = torch.nn.utils.parametrizations.weight_norm(model, dim=dim) 2903 2904 cpu_out = cpu_norm(cpu_x) 2905 out = norm(x) 2906 2907 self.assertEqual(cpu_out, out) 2908 2909 cpu_grad = torch.randn(cpu_out.shape) 2910 grad = cpu_grad.to('mps') 2911 cpu_out.backward(gradient=cpu_grad) 2912 out.backward(gradient=grad) 2913 2914 self.assertEqual(cpu_model.parametrizations.weight.original0.grad, model.parametrizations.weight.original0.grad) 2915 self.assertEqual(cpu_model.parametrizations.weight.original1.grad, model.parametrizations.weight.original1.grad) 2916 2917 self.assertEqual(x.grad, cpu_x.grad) 2918 2919 def helper(dim, layer='linear', dtype=torch.float32): 2920 # linear layer 2921 if layer == 'linear': 2922 cpu_x = torch.randn((2, 5), device='cpu', dtype=dtype, requires_grad=True) 2923 x = cpu_x.detach().clone().to('mps').requires_grad_() 2924 2925 cpu_weight = torch.randn(10, 5, device='cpu', dtype=dtype, requires_grad=True) 2926 weight = cpu_weight.detach().clone().to('mps').requires_grad_() 2927 2928 cpu_bias = torch.randn(10, device='cpu', dtype=dtype, requires_grad=True) 2929 bias = cpu_bias.detach().clone().to('mps').requires_grad_() 2930 2931 cpu_linear = torch.nn.Linear(5, 10, device='cpu') 2932 linear = torch.nn.Linear(5, 10, device='mps') 2933 2934 with torch.no_grad(): 2935 cpu_linear.weight.copy_(cpu_weight) 2936 cpu_linear.bias.copy_(cpu_bias) 2937 linear.weight.copy_(weight) 2938 linear.bias.copy_(bias) 2939 validate_weight_norm_equality(linear, cpu_linear, x, cpu_x, dim) 2940 2941 # conv layer 2942 if layer == 'conv': 2943 cpu_x = torch.randn((3, 5, 5), device='cpu', dtype=dtype, requires_grad=True) 2944 x = cpu_x.detach().clone().to('mps').requires_grad_() 2945 2946 cpu_conv = torch.nn.Conv2d(3, 3, 3, device='cpu') 2947 conv = torch.nn.Conv2d(3, 3, 3, device='mps') 2948 2949 with torch.no_grad(): 2950 conv.weight.copy_(cpu_conv.weight) 2951 conv.bias.copy_(cpu_conv.bias) 2952 2953 validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim) 2954 2955 # conv3d layer 2956 if layer == 'conv3d': 2957 cpu_x = torch.randn((3, 5, 5, 4), device='cpu', dtype=dtype, requires_grad=True) 2958 x = cpu_x.detach().clone().to('mps').requires_grad_() 2959 2960 cpu_conv = torch.nn.Conv3d(3, 3, 3, device='cpu') 2961 conv = torch.nn.Conv3d(3, 3, 3, device='mps') 2962 2963 with torch.no_grad(): 2964 conv.weight.copy_(cpu_conv.weight) 2965 conv.bias.copy_(cpu_conv.bias) 2966 2967 validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim) 2968 2969 helper(0, layer='linear') 2970 helper(1, layer='linear') 2971 helper(-1, layer='linear') 2972 2973 helper(0, layer='conv') 2974 helper(1, layer='conv') 2975 helper(2, layer='conv') 2976 helper(3, layer='conv') 2977 helper(-1, layer='conv') 2978 2979 if product_version >= 13.2: 2980 # Conv3d is only available from MacOS 13 onwards 2981 helper(0, layer='conv3d') 2982 helper(1, layer='conv3d') 2983 helper(2, layer='conv3d') 2984 helper(3, layer='conv3d') 
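# weight_norm reparameterizes the weight as w = g * v / ||v||, keeping one
# norm per slice along the chosen dim; Conv3d weights are 5-D
# (out_channels, in_channels, kT, kH, kW), so dims 0 through 4 and the
# dim=-1 case are all exercised for it.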
2985 helper(4, layer='conv3d') 2986 helper(-1, layer='conv3d') 2987 2988 # Test conv2d 2989 def test_conv2d_unit(self): 2990 def helper(input_shape, wt_shape, 2991 stride=1, padding=0, 2992 dilation=1, groups=1, 2993 bias_shape=None): 2994 2995 cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True) 2996 x = cpu_x.detach().clone().to('mps').requires_grad_() 2997 2998 cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True) 2999 wt = cpu_wt.detach().clone().to('mps').requires_grad_() 3000 3001 cpu_bias = None 3002 bias = None 3003 3004 if (bias_shape is not None): 3005 cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True) 3006 bias = cpu_bias.detach().clone().to('mps').requires_grad_() 3007 3008 y = torch.nn.functional.conv2d(x, wt, bias=bias, stride=stride, 3009 padding=padding, dilation=dilation, groups=groups) 3010 ref_y = torch.nn.functional.conv2d(cpu_x, cpu_wt, bias=cpu_bias, stride=stride, 3011 padding=padding, dilation=dilation, groups=groups) 3012 3013 cpu_grad = torch.ones_like(ref_y) 3014 grad = cpu_grad.to('mps') 3015 3016 y.backward(gradient=grad) 3017 ref_y.backward(gradient=cpu_grad) 3018 3019 self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04) 3020 self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05) 3021 self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05) 3022 if (bias_shape is not None): 3023 self.assertEqual(bias.grad, cpu_bias.grad, atol=8e-04, rtol=10.4e-05) 3024 3025 N = 1 3026 C_in = 3 3027 C_out = 64 3028 H = 64 3029 W = 64 3030 kH = 4 3031 kW = 4 3032 stride = 2 3033 padding = 1 3034 3035 helper((N, C_in, H, W), (C_out, C_in, kH, kW), stride=stride, padding=padding) 3036 3037 N = 4 3038 C_in = 16 3039 H = 32 3040 W = 32 3041 3042 C_out = 8 3043 kH = 3 3044 kW = 3 3045 3046 for groups in [1, 2, 4]: 3047 helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups) 3048 helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups) 3049 3050 helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups) 3051 helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups) 3052 3053 helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups) 3054 helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups) 3055 3056 helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, 3057 kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups) 3058 helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, 3059 kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups) 3060 3061 # Test conv transpose 2d 3062 def test_conv_transpose2d(self): 3063 def helper(input_shape, wt_shape, 3064 stride=1, padding=0, 3065 output_padding=0, 3066 dilation=1, groups=1, 3067 bias_shape=None): 3068 3069 cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True) 3070 x = cpu_x.detach().clone().to('mps').requires_grad_() 3071 3072 cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True) 3073 wt = cpu_wt.detach().clone().to('mps').requires_grad_() 3074 3075 cpu_bias = None 3076 bias = None 3077 3078 if (bias_shape is not None): 3079 cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True) 3080 bias = cpu_bias.detach().clone().to('mps').requires_grad_() 3081 3082 y = torch.nn.functional.conv_transpose2d( 3083 x, wt, bias=bias, 
stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 3084 ref_y = torch.nn.functional.conv_transpose2d( 3085 cpu_x, cpu_wt, bias=cpu_bias, stride=stride, padding=padding, 3086 output_padding=output_padding, groups=groups, dilation=dilation) 3087 3088 cpu_grad = torch.randn(ref_y.shape) 3089 grad = cpu_grad.to('mps') 3090 3091 y.backward(gradient=grad) 3092 ref_y.backward(gradient=cpu_grad) 3093 3094 self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04) 3095 self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05) 3096 self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05) 3097 3098 # if (bias_shape is not None): 3099 # print(cpu_bias.grad) 3100 # print(bias.grad.to('cpu')) 3101 # self.assertEqual(bias.grad, cpu_bias.grad) 3102 3103 N = 4 3104 C_in = 2 3105 H = 32 3106 W = 32 3107 3108 C_out = 8 3109 groups = 1 3110 kH = 3 3111 kW = 3 3112 3113 for stride in [1, 2, 3]: 3114 for padding in [0, 1, 2]: 3115 for output_padding in [0, 1, 2]: 3116 for dilation in [1, 2]: 3117 if (output_padding >= stride or output_padding >= dilation): 3118 continue 3119 helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride, 3120 padding=padding, output_padding=output_padding, dilation=dilation) 3121 helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride, 3122 padding=padding, output_padding=output_padding, dilation=dilation) 3123 3124 helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride, 3125 padding=padding, output_padding=output_padding, dilation=dilation) 3126 helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride, 3127 padding=padding, output_padding=output_padding, dilation=dilation) 3128 3129 # Test sigmoid 3130 def test_sigmoid(self): 3131 def helper(shape): 3132 3133 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 3134 x = cpu_x.detach().clone().to('mps').requires_grad_() 3135 3136 sigmoid_op = torch.nn.Sigmoid() 3137 3138 y = sigmoid_op(x) 3139 ref_y = sigmoid_op(cpu_x) 3140 3141 cpu_grad = torch.ones_like(ref_y) 3142 grad = cpu_grad.to('mps') 3143 3144 y.backward(gradient=grad) 3145 ref_y.backward(gradient=cpu_grad) 3146 3147 self.assertEqual(y, ref_y) 3148 self.assertEqual(x.grad, cpu_x.grad) 3149 3150 helper((2, 3, 4, 5)) 3151 helper((2, 3, 4)) 3152 helper((2, 8, 4, 5)) 3153 3154 # Test tanh 3155 def test_tanh(self): 3156 def helper(shape): 3157 3158 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 3159 x = cpu_x.detach().clone().to('mps').requires_grad_() 3160 3161 tanh_op = torch.nn.Tanh() 3162 3163 y = tanh_op(x) 3164 ref_y = tanh_op(cpu_x) 3165 3166 cpu_grad = torch.ones_like(ref_y) 3167 grad = cpu_grad.to('mps') 3168 3169 y.backward(gradient=grad) 3170 ref_y.backward(gradient=cpu_grad) 3171 3172 self.assertEqual(y, ref_y) 3173 self.assertEqual(x.grad, cpu_x.grad) 3174 3175 helper((2, 3, 4, 5)) 3176 helper((2, 3, 4)) 3177 helper((2, 8, 4, 5)) 3178 3179 def test_threshold(self): 3180 def helper(threshold, value, num_elems, inplace=False, requires_grad=True): 3181 m = nn.Threshold(threshold=threshold, value=value, inplace=inplace) 3182 3183 input_cpu = torch.randn(num_elems, requires_grad=requires_grad, dtype=torch.float) 3184 input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad) 3185 3186 output_cpu = m(input_cpu) 3187 output_mps = m(input_mps) 3188 3189 cpu_grad = torch.ones_like(output_cpu) 3190 mps_grad = cpu_grad.to('mps') 3191 3192 self.assertEqual(output_cpu, output_mps) 3193 
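# The backward comparison below only runs when requires_grad is set; the
# inplace=True case uses requires_grad=False, since an in-place op on a
# leaf tensor that requires grad would raise an error.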
3194 if requires_grad: 3195 output_cpu.backward(gradient=cpu_grad) 3196 output_mps.backward(gradient=mps_grad) 3197 3198 self.assertEqual(input_cpu.grad, input_mps.grad) 3199 3200 helper(threshold=0.1, value=20, num_elems=2) 3201 helper(threshold=-0.1, value=10, num_elems=10) 3202 helper(threshold=0.5, value=-15, num_elems=100) 3203 helper(threshold=1, value=10, num_elems=100, inplace=True, requires_grad=False) 3204 3205 # Test pow 3206 def test_pow(self): 3207 def helper(shape): 3208 # aten::pow.Tensor_Tensor 3209 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 3210 x = cpu_x.detach().clone().to('mps') 3211 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 3212 y = cpu_y.detach().clone().to('mps') 3213 z = torch.pow(x, y) 3214 ref_z = torch.pow(cpu_x, cpu_y) 3215 3216 self.assertEqual(z, ref_z) 3217 3218 # aten::pow.Tensor_Scalar 3219 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 3220 x = cpu_x.detach().clone().to('mps') 3221 exp = random.random() 3222 z = torch.pow(x, exp) 3223 ref_z = torch.pow(cpu_x, exp) 3224 3225 self.assertEqual(z, ref_z) 3226 3227 # aten::pow.Scalar 3228 x = random.random() 3229 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 3230 y = cpu_y.detach().clone().to('mps') 3231 z = torch.pow(x, y) 3232 ref_z = torch.pow(x, cpu_y) 3233 3234 self.assertEqual(z, ref_z) 3235 3236 helper((2, 8, 4, 5)) 3237 3238 # Test addcmul 3239 def test_addcmul(self): 3240 def helper(shape, value, xtype=torch.float32, ytype=None, ztype=None): 3241 def rand_helper(dtype): 3242 if dtype.is_floating_point: 3243 return torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) 3244 return torch.randint(10, shape, dtype=dtype, device='cpu', requires_grad=False) 3245 3246 cpu_x = rand_helper(xtype) 3247 x = cpu_x.detach().clone().to('mps') 3248 3249 cpu_y = rand_helper(ytype if ytype is not None else xtype) 3250 y = cpu_y.detach().clone().to('mps') 3251 3252 cpu_z = rand_helper(ztype if ztype is not None else xtype) 3253 z = cpu_z.detach().clone().to('mps') 3254 3255 y = torch.addcmul(x, y, z, value=value) 3256 ref_y = torch.addcmul(cpu_x, cpu_y, cpu_z, value=value) 3257 3258 self.assertEqual(y, ref_y) 3259 3260 helper((2, 3, 4, 5), 0.1) 3261 helper((2, 8, 4, 5), 0.1) 3262 helper((2, 3, 4, 5), 0.2) 3263 helper((2, 8, 4, 5), 0.2) 3264 # Integral types 3265 helper((2, 2), 1.0, xtype=torch.int32) 3266 helper((2, 2), 2.0, xtype=torch.int16) 3267 3268 # Mixed types 3269 helper((2, 2), 1.0, xtype=torch.float16, ytype=torch.float32) 3270 helper((3, 2), 1.0, ytype=torch.float16) 3271 helper((2, 3), 1.0, ztype=torch.float16) 3272 helper((2, 2), 1.0, xtype=torch.int32, ytype=torch.int16, ztype=torch.uint8) 3273 helper((2, 2), 1.0, ytype=torch.int16, ztype=torch.uint8) 3274 3275 # Test addcdiv 3276 def test_addcdiv(self): 3277 def helper(shape, value): 3278 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 3279 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 3280 # clamp to avoid division by 0 3281 cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False).clamp_min_(0.1) 3282 cpu_out = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 3283 3284 mps_x = cpu_x.detach().clone().to('mps') 3285 mps_y = cpu_y.detach().clone().to('mps') 3286 mps_z = cpu_z.detach().clone().to('mps') 3287 mps_out = cpu_out.detach().clone().to('mps') 3288 3289 result_div_mps = 
torch.addcdiv(mps_x, mps_y, mps_z, value=value) 3290 result_div_cpu = torch.addcdiv(cpu_x, cpu_y, cpu_z, value=value) 3291 self.assertEqual(result_div_mps, result_div_cpu) 3292 # test .out variant 3293 self.assertEqual(torch.addcdiv(mps_x, mps_y, mps_z, out=mps_out, value=value), result_div_cpu) 3294 3295 helper((2, 3, 4, 5), 0.1) 3296 helper((2, 8, 4, 5), 0.2) 3297 helper((2, 3, 4, 5), 1.0) # value of 1 should be ignored internally 3298 3299 def test_addcdiv_transpose(self): 3300 # Regression test for issue https://github.com/pytorch/pytorch/issues/118115 3301 # Testing contiguity of all input tensors 3302 3303 def helper(shape, value): 3304 shape_t = shape[::-1] 3305 for i in range(2): 3306 for j in range(2): 3307 for k in range(2): 3308 x = torch.rand(shape, device="cpu") if i == 0 else torch.rand(shape_t, device="cpu").t() 3309 y = torch.rand(shape, device="cpu") if j == 0 else torch.rand(shape_t, device="cpu").t() 3310 z = torch.rand(shape, device="cpu") if k == 0 else torch.rand(shape_t, device="cpu").t() 3311 3312 x_mps = x.detach().clone().to(device="mps") 3313 y_mps = y.detach().clone().to(device="mps") 3314 z_mps = z.detach().clone().to(device="mps") 3315 3316 result_cpu = x.addcdiv_(y, z, value=value) 3317 result_mps = x_mps.addcdiv(y_mps, z_mps, value=value) 3318 result_mps_out = result_cpu.detach().clone().to('mps') 3319 torch.addcdiv(x_mps, y_mps, z_mps, out=result_mps_out, value=value) 3320 3321 self.assertEqual(result_cpu, result_mps) 3322 self.assertEqual(result_cpu, result_mps_out) 3323 3324 helper((2, 3), 1.0) 3325 helper((2, 3), 0.2) 3326 helper((100, 300), 1.0) 3327 helper((100, 300), 0.2) 3328 3329 def test_buffer_size_match(self): 3330 # this test shouldn't cause any crash 3331 size = 16 3332 cpu_A = torch.rand(size, device='cpu') 3333 cpu_F = torch.rand(size, size, size, device='cpu') 3334 3335 mps_A = cpu_A.to('mps') 3336 mps_F = cpu_F.to('mps') 3337 self.assertEqual(cpu_A @ cpu_F, mps_A @ mps_F) 3338 3339 def test_transpose_inplace(self): 3340 values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] 3341 cpu_x = torch.tensor(values, device='cpu') 3342 mps_x = torch.tensor(values, device='mps') 3343 3344 cpu_x.transpose_(0, 1) 3345 mps_x.transpose_(0, 1) 3346 self.assertEqual(cpu_x, mps_x.to('cpu')) 3347 3348 def test_expand_cpu_to_mps_copy(self): 3349 # https://github.com/pytorch/pytorch/issues/78642 3350 3351 x = torch.tensor(1).expand([10]).to("mps") 3352 x_cpu = torch.tensor(1).expand([10]) 3353 3354 self.assertEqual(x_cpu, x.cpu()) 3355 3356 def test_cpu_to_strided_mps_copy(self): 3357 # https://github.com/pytorch/pytorch/issues/86975 3358 3359 a1 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps")) 3360 b1 = torch.Tensor([-1, -1]) 3361 a1[1:, 1] = b1 3362 3363 a2 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps")) 3364 b2 = torch.Tensor([-1, -1]).to(torch.device("mps")) 3365 a2[1:, 1] = b2 3366 3367 self.assertEqual(a1, a2) 3368 3369 def test_view_slice_reshape(self): 3370 x = torch.randn([1, 4, 4], device="mps") 3371 y = x[0, :1, 1:] 3372 3373 x_cpu = x.to("cpu") 3374 y_cpu = x_cpu[0, :1, 1:] 3375 3376 r = y + 1 3377 r_cpu = y_cpu + 1 3378 self.assertEqual(r, r_cpu) 3379 3380 def test_slice_reshape(self): 3381 x = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps") 3382 x_cpu = x.detach().clone().to("cpu") 3383 3384 x = x[:, 3:].view(2, 3, 4, 1) 3385 x_cpu = x_cpu[:, 3:].view(2, 3, 4, 1) 3386 self.assertEqual(x, x_cpu) 3387 3388 x = x + 2 3389 x_cpu = x_cpu + 2 3390 self.assertEqual(x, x_cpu) 3391 3392 def
test_reshape_storage_offset(self): 3393 # https://github.com/pytorch/pytorch/issues/95883 3394 B = 4 3395 T = 1 3396 3397 lin_cpu = nn.Linear(10, 256) 3398 lin_mps = nn.Linear(10, 256, device="mps") 3399 3400 # Use the same weights and bias as the ones from the cpu 3401 lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_() 3402 lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_() 3403 3404 x_mps = torch.rand([B, T, 10], device="mps", requires_grad=True) 3405 x_cpu = x_mps.detach().clone().cpu().requires_grad_() 3406 x_mps = lin_mps(x_mps) 3407 x_cpu = lin_cpu(x_cpu) 3408 3409 self.assertEqual(x_mps.shape, (B, T, 256)) 3410 self.assertEqual(x_cpu.shape, (B, T, 256)) 3411 3412 cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1) 3413 cls_token_cpu = cls_token_mps.detach().clone().cpu() 3414 x_mps = torch.cat([cls_token_mps, x_mps], dim=1) 3415 x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1) 3416 3417 x_mps = x_mps.transpose(0, 1) 3418 x_cpu = x_cpu.transpose(0, 1) 3419 3420 target_mps = torch.rand_like(x_mps) 3421 target_cpu = target_mps.detach().clone().cpu() 3422 loss_mps = F.mse_loss(x_mps, target_mps) 3423 loss_cpu = F.mse_loss(x_cpu, target_cpu) 3424 self.assertEqual(loss_mps, loss_cpu) 3425 3426 loss_mps.backward() 3427 loss_cpu.backward() 3428 self.assertEqual(x_mps.grad, x_cpu.grad) 3429 3430 def test_stack_storage_offset(self): 3431 # https://github.com/pytorch/pytorch/issues/87856 3432 x_cpu = torch.tensor([[1, 2]]) 3433 x_mps = x_cpu.detach().clone().to("mps") 3434 3435 y_cpu = torch.stack((x_cpu[:, :1], x_cpu[:, -1:]), dim=-1) 3436 y_mps = torch.stack((x_mps[:, :1], x_mps[:, -1:]), dim=-1) 3437 3438 self.assertEqual(y_cpu, y_mps) 3439 3440 t_mps = torch.tensor([1, 2, 3, 4], device="mps") 3441 t_cpu = t_mps.detach().cpu() 3442 3443 x_mps = t_mps[2:] 3444 y_mps = t_mps[:2] 3445 3446 x_cpu = t_cpu[2:] 3447 y_cpu = t_cpu[:2] 3448 3449 res_mps = torch.stack((y_mps, x_mps), dim=-1) 3450 res_cpu = torch.stack((y_cpu, x_cpu), dim=-1) 3451 3452 self.assertEqual(res_mps, res_cpu) 3453 3454 def test_unsafe_chunk(self): 3455 # https://github.com/pytorch/pytorch/issues/91065 3456 a = torch.rand(5, dtype=torch.float32, device="cpu") 3457 ret = a.unsafe_chunk(4, 0) 3458 y = ret[0] * ret[2] 3459 a_mps = a.to("mps") 3460 ret_mps = a_mps.unsafe_chunk(4, 0) 3461 y_mps = ret_mps[0] * ret_mps[2] 3462 self.assertEqual(y, y_mps) 3463 3464 def test_slice_casting(self): 3465 # generate random binary numbers 3466 cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8) 3467 mps_in = cpu_in.detach().clone().to("mps") 3468 # check copy_cast(uint8 -> bool) on tensors with storage offset 3469 cpu_out = cpu_in[:, :, 11 : 12, :12].to(torch.bool) 3470 mps_out = mps_in[:, :, 11 : 12, :12].to(torch.bool) 3471 self.assertEqual(cpu_out, mps_out) 3472 3473 def test_slice_reshape_contg_view(self): 3474 3475 3476 x_mps = torch.randn(1, 4800, 2, device="mps") 3477 x_cpu = x_mps.detach().clone().cpu() 3478 3479 r_mps = x_mps + 2 3480 r_cpu = x_cpu + 2 3481 3482 self.assertEqual(r_mps, r_cpu) 3483 3484 def test_contiguous_slice_2d(self): 3485 def helper(shape): 3486 for i in range(0, shape[0]): 3487 for j in range(0, shape[1]): 3488 t_mps = torch.randn(shape, device="mps") 3489 t_cpu = t_mps.detach().clone().cpu() 3490 3491 y_mps = t_mps[i:, :j] 3492 y_cpu = t_cpu[i:, :j] 3493 self.assertEqual(y_mps + 1, y_cpu + 1) 3494 3495 y_mps = t_mps[i:, j] 3496 y_cpu = t_cpu[i:, j] 3497
self.assertEqual(y_mps + 1, y_cpu + 1) 3498 3499 y_mps = t_mps[i, :j] 3500 y_cpu = t_cpu[i, :j] 3501 self.assertEqual(y_mps + 1, y_cpu + 1) 3502 3503 y_mps = t_mps[:i, :j] 3504 y_cpu = t_cpu[:i, :j] 3505 self.assertEqual(y_mps + 1, y_cpu + 1) 3506 3507 y_mps = t_mps[:i, j] 3508 y_cpu = t_cpu[:i, j] 3509 self.assertEqual(y_mps + 1, y_cpu + 1) 3510 3511 y_mps = t_mps[:i, j:] 3512 y_cpu = t_cpu[:i, j:] 3513 self.assertEqual(y_mps + 1, y_cpu + 1) 3514 3515 l = [] 3516 for N in range(1, 3): 3517 l.append(N) 3518 for C in range(1, 3): 3519 l.append(C) 3520 helper(l) 3521 for D in range(1, 3): 3522 l.append(D) 3523 helper(l) 3524 for H in range(1, 3): 3525 l.append(H) 3526 helper(l) 3527 for W in range(1, 3): 3528 l.append(W) 3529 helper(l) 3530 l.pop() 3531 l.pop() 3532 l.pop() 3533 l.pop() 3534 l.pop() 3535 3536 helper([9, 15, 4]) 3537 helper([9, 3, 2]) 3538 helper([3, 4, 18, 22]) 3539 helper([3, 4, 18, 22, 150]) 3540 3541 def test_contiguous_slice_3d(self): 3542 x = torch.randn(2, 3, 3, device="mps") 3543 x_cpu = x.detach().clone().cpu() 3544 x = x[:1] 3545 x_cpu = x_cpu[:1] 3546 out = x[:, 0:1, 0:1] * x[:, 1:2, 1:2] 3547 out_cpu = x_cpu[:, 0:1, 0:1] * x_cpu[:, 1:2, 1:2] 3548 self.assertEqual(out, out_cpu) 3549 3550 def test_view_slice(self): 3551 # https://github.com/pytorch/pytorch/issues/83995 3552 NUM_SAMPLES = 60 3553 s = (0, 1) 3554 3555 X = torch.rand(8000, 3, dtype=torch.float32, device='cpu') 3556 X_mps = X.detach().clone().to("mps") 3557 3558 idx = torch.randint(0, X.shape[0], (1,)).repeat(len(s)) 3559 pts = torch.randint(0, X.shape[0], (NUM_SAMPLES, X.shape[1])) 3560 idx_mps = idx.to("mps") 3561 pts_mps = pts.to("mps") 3562 pts[:, s] = idx 3563 pts_mps[:, s] = idx_mps 3564 3565 actual_pts = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float) 3566 actual_pts_mps = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float, device="mps") 3567 3568 for i in range(NUM_SAMPLES): 3569 for j in range(X.shape[1]): 3570 actual_pts_mps[i, j] = X_mps[pts_mps[i, j], j] 3571 actual_pts[i, j] = X[pts[i, j], j] 3572 self.assertEqual(actual_pts[i, j], actual_pts_mps[i, j]) 3573 3574 def test_slice_scatter(self): 3575 shape = (4, 4) 3576 tensor = torch.randint(10, shape, device="mps") 3577 tensor_before = tensor.clone() 3578 torch.empty(shape[0], shape[1] * 2, device="mps")[:, ::2].copy_(tensor) 3579 torch.testing.assert_close(tensor, tensor_before) 3580 3581 def test_slice(self): 3582 values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] 3583 cpu_x = torch.tensor(values, device='cpu') 3584 mps_x = torch.tensor(values, device='mps', dtype=torch.float) 3585 3586 cpu_slice1 = cpu_x[:2, :] 3587 mps_slice1 = mps_x[:2, :] 3588 self.assertEqual(cpu_slice1, mps_slice1) 3589 3590 cpu_slice2 = cpu_x[:, :1] 3591 mps_slice2 = mps_x[:, :1] 3592 self.assertEqual(cpu_slice2, mps_slice2) 3593 3594 cpu_slice3 = cpu_x[1:2, :] 3595 mps_slice3 = mps_x[1:2, :] 3596 self.assertEqual(cpu_slice3, mps_slice3.to('cpu')) 3597 3598 cpu_slice4 = cpu_x[1, :] 3599 mps_slice4 = mps_x[1, :].to('cpu') 3600 self.assertEqual(cpu_slice4, mps_slice4) 3601 3602 @parametrize("torch_type", arg_values=[torch.float16, torch.float32, torch.bfloat16]) 3603 def test_slice_view_api(self, torch_type: torch.dtype): 3604 3605 def helper(x_tensor, y_func, z_func, r_func=None): 3606 x_mps = x_tensor.detach().clone().to("mps") 3607 3608 y = y_func(x_tensor) 3609 y_mps = y_func(x_mps) 3610 self.assertEqual(y, y_mps) 3611 3612 z = z_func(y) 3613 z_mps = z_func(y_mps) 3614 self.assertEqual(z, z_mps) 3615 self.assertEqual(z.storage_offset(),
z_mps.storage_offset()) 3616 3617 if r_func: 3618 r = r_func(z) 3619 r_mps = r_func(z_mps) 3620 self.assertEqual(r, r_mps) 3621 3622 # Skip bfloat16 before macOS 15 3623 if not (product_version < 15.0 and torch_type == torch.bfloat16): 3624 # Tests for previously encountered MPS bugs 3625 helper( 3626 torch.randn(4, 4, dtype=torch_type), 3627 lambda x: x[1], 3628 lambda y: y.reshape(2, 2), 3629 lambda z: z + 1 3630 ) 3631 helper( 3632 torch.randn(2, 4, dtype=torch_type), 3633 lambda x: x[1], 3634 lambda y: y + torch.ones(4, device=y.device) 3635 ) 3636 helper( 3637 torch.randn(4, 6, dtype=torch_type), 3638 lambda x: x[1], 3639 lambda y: y.reshape(3, 2).t(), 3640 lambda z: z + 1 3641 ) 3642 helper( 3643 torch.arange(4, dtype=torch_type).resize(1, 2, 2), 3644 lambda x: x.permute(2, 0, 1), 3645 lambda y: y + 1 3646 ) 3647 helper( 3648 torch.randn(4, 8, dtype=torch_type), 3649 lambda x: x.transpose(0, 1).reshape(-1), 3650 lambda y: y[:2], 3651 lambda z: z + 1 3652 ) 3653 helper( 3654 torch.randn(1, dtype=torch_type), 3655 lambda x: x.expand(2, 3), 3656 lambda y: y + torch.ones(2, 3, device=y.device) 3657 ) 3658 3659 def test_slice_reshape_contiguous(self): 3660 x = torch.randn(4, 4) 3661 x_mps = x.detach().clone().to("mps") 3662 3663 y = x[1] 3664 y_mps = x_mps[1] 3665 self.assertEqual(y, y_mps) 3666 3667 z = y.reshape(2, 2) 3668 z_mps = y_mps.reshape(2, 2) 3669 self.assertEqual(z, z_mps) 3670 self.assertEqual(z.storage_offset(), z_mps.storage_offset()) 3671 3672 def test_scalar_from_slice_unary(self): 3673 # https://github.com/pytorch/pytorch/issues/82543 3674 tensor_list = torch.tensor([1.0, 1.2], device="mps") 3675 3676 for scalar in tensor_list: 3677 r_mps = torch.ceil(scalar) 3678 r_cpu = torch.ceil(scalar.to("cpu")) 3679 self.assertEqual(r_mps.cpu(), r_cpu) 3680 3681 def test_scalar_from_slice_binary(self): 3682 # https://github.com/pytorch/pytorch/issues/82543 3683 def helper(binary_op): 3684 tensor_list = torch.tensor([1.0, 1.2, 2.5, 1.0], device="mps") 3685 3686 for scalar in tensor_list: 3687 r_mps = binary_op(scalar, 1.0) 3688 r_cpu = binary_op(scalar.cpu(), 1.0) 3689 self.assertEqual(r_mps.cpu(), r_cpu) 3690 helper(torch.sub) 3691 helper(torch.add) 3692 helper(torch.not_equal) 3693 helper(torch.eq) 3694 3695 def test_slice_contiguous_view(self): 3696 # https://github.com/pytorch/pytorch/issues/77750 3697 3698 def helper(operator): 3699 t_mps = torch.tensor([1, 2, 3, 4], device="mps") 3700 t_cpu = torch.tensor([1, 2, 3, 4], device="cpu") 3701 3702 # contiguous view 3703 x_mps = t_mps[2:] # 3, 4 3704 y_mps = t_mps[:2] # 1, 2 3705 3706 x_cpu = t_cpu[2:] 3707 y_cpu = t_cpu[:2] 3708 3709 res_mps = res_cpu = None 3710 if operator == "<=": 3711 res_mps = x_mps <= y_mps 3712 res_cpu = x_cpu <= y_cpu 3713 elif operator == "<": 3714 res_mps = x_mps < y_mps 3715 res_cpu = x_cpu < y_cpu 3716 elif operator == ">=": 3717 res_mps = x_mps >= y_mps 3718 res_cpu = x_cpu >= y_cpu 3719 elif operator == ">": 3720 res_mps = x_mps > y_mps 3721 res_cpu = x_cpu > y_cpu 3722 elif operator == "==": 3723 res_mps = x_mps == y_mps 3724 res_cpu = x_cpu == y_cpu 3725 elif operator == "!=": 3726 res_mps = x_mps != y_mps 3727 res_cpu = x_cpu != y_cpu 3728 elif operator == "stack": 3729 res_mps = torch.stack((y_mps, x_mps), dim=-1) 3730 res_cpu = torch.stack((y_cpu, x_cpu), dim=-1) 3731 3732 self.assertEqual(res_mps, res_cpu) 3733 3734 for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]: 3735 helper(op) 3736 3737 def test_slice_of_slice(self): 3738 x = torch.tensor([0.5, 0.5], device="cpu") 3739 x_mps =
torch.tensor([0.5, 0.5], device="mps") 3740 3741 tensor = x[1][None] 3742 tensor_mps = x_mps[1][None] 3743 3744 res = tensor.ne(0) 3745 res_mps = tensor_mps.ne(0) 3746 3747 self.assertEqual(res, res_mps) 3748 3749 def test_index_storage_offset(self): 3750 # https://github.com/pytorch/pytorch/issues/78107 3751 3752 a = torch.tensor([8.2670e-01, -1.0293e+00]) 3753 b_cpu = a[0] 3754 c_cpu = a[1] 3755 3756 # both 'b' and 'c' are views of 'a' 3757 # 'b' has a storage offset of 0, while 'c' has a storage offset of 1 3758 # when copying from 'cpu' to 'mps', c will have a storage_offset of 1 which needs to be taking into account, 3759 # otherwise it ends with same value as 'b' 3760 b = b_cpu.to('mps') 3761 c = c_cpu.to('mps') 3762 3763 res_mps = b > c 3764 res_cpu = b_cpu > c_cpu 3765 self.assertEqual(res_mps, res_cpu) 3766 3767 res_mps = c > b 3768 res_cpu = c_cpu > b_cpu 3769 self.assertEqual(res_mps, res_cpu) 3770 3771 def test_flatten(self): 3772 values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]] 3773 cpu_x = torch.tensor(values, device='cpu') 3774 mps_x = torch.tensor(values, device='mps') 3775 3776 cpu_flatten1 = cpu_x.flatten() 3777 mps_flatten1 = mps_x.flatten().to('cpu') 3778 self.assertEqual(cpu_flatten1, mps_flatten1) 3779 3780 cpu_flatten2 = cpu_x.flatten(start_dim=1) 3781 mps_flatten2 = mps_x.flatten(start_dim=1).to('cpu') 3782 self.assertEqual(cpu_flatten2, mps_flatten2) 3783 3784 cpu_flatten3 = cpu_x.flatten(end_dim=1) 3785 mps_flatten3 = mps_x.flatten(end_dim=1).to('cpu') 3786 self.assertEqual(cpu_flatten3, mps_flatten3) 3787 3788 # Test repeat 3789 def test_repeat(self): 3790 def helper(shape, repeats): 3791 3792 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 3793 x = cpu_x.detach().clone().to('mps').requires_grad_() 3794 3795 y = x.repeat(repeats) 3796 ref_y = cpu_x.repeat(repeats) 3797 3798 cpu_grad = torch.randn(ref_y.shape) 3799 grad = cpu_grad.to('mps') 3800 3801 y.backward(gradient=grad) 3802 ref_y.backward(gradient=cpu_grad) 3803 3804 self.assertEqual(y, ref_y) 3805 self.assertEqual(x.grad, cpu_x.grad) 3806 3807 helper((2, 3, 4, 5), (2, 3, 4, 5)) 3808 helper((2, 3, 4), (4, 3, 2, 5, 7, 2)) 3809 helper((3, 4, 5), (2, 3, 4, 5)) 3810 helper((3, 4, 5), (2, 2, 2)) 3811 3812 def test_torch_repeat_interleave(self, device="mps"): 3813 y = torch.tensor([[1, 2], [3, 4]], device=device) 3814 # exercise single argument function signature 3815 temp = y.repeat_interleave(2) 3816 self.assertEqual(torch.Size([8]), temp.size()) 3817 3818 for dtype in [torch.int, torch.long]: 3819 lengths = torch.tensor([1, 2], dtype=dtype, device="mps") 3820 output_size = torch.sum(lengths) 3821 a = torch.repeat_interleave( 3822 y, 3823 lengths, 3824 dim=0, 3825 ) 3826 self.assertEqual(a.dtype, y.dtype) 3827 self.assertEqual(a.size(), torch.Size([3, 2])) 3828 3829 a_with_output = torch.repeat_interleave( 3830 y, 3831 lengths, 3832 dim=0, 3833 output_size=output_size, 3834 ) 3835 self.assertEqual(a_with_output.dtype, y.dtype) 3836 self.assertEqual(a_with_output.size(), torch.Size([3, 2])) 3837 3838 def test_repeat_interleave(self, device="mps"): 3839 x = torch.tensor([0, 1, 2, 3], device=device) 3840 expected = torch.tensor([1, 2, 2, 3, 3, 3], device=device) 3841 # Prior to macos 13.3, input of dtype=torch.int64 returns dtype=torch.int32 3842 self.assertEqual(torch.repeat_interleave(x), expected, exact_dtype=product_version >= 13.3) 3843 3844 with self.assertRaises(RuntimeError): 3845 torch.repeat_interleave(torch.arange(4, 
device=device).reshape(2, 2)) 3846 3847 with self.assertRaises(RuntimeError): 3848 torch.repeat_interleave(torch.arange(4.0, device=device)) 3849 3850 with self.assertRaises(RuntimeError): 3851 torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4], device=device)) 3852 3853 y = torch.tensor([[1, 2], [3, 4]], device=device) 3854 3855 y1_v1 = torch.repeat_interleave(y, 2) 3856 y1_v2 = torch.repeat_interleave(y, torch.tensor(2, device=device)) 3857 y1_v3 = torch.repeat_interleave(y, torch.tensor([2], device=device)) 3858 y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4], device=device) 3859 self.assertEqual(y1_v1, y1_expect) 3860 self.assertEqual(y1_v2, y1_expect) 3861 self.assertEqual(y1_v3, y1_expect) 3862 3863 y2 = torch.repeat_interleave(y, 3, dim=1) 3864 y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2], 3865 [3, 3, 3, 4, 4, 4]], device=device) 3866 self.assertEqual(y2, y2_expect) 3867 3868 y3 = torch.repeat_interleave(y, torch.tensor([1, 2], device=device), dim=0) 3869 y3_expect = torch.tensor([[1, 2], 3870 [3, 4], 3871 [3, 4]], device=device) 3872 self.assertEqual(y3, y3_expect) 3873 3874 with self.assertRaises(RuntimeError): 3875 torch.repeat_interleave(y, torch.tensor([1, 2, 3], device=device), dim=0) 3876 3877 with self.assertRaises(RuntimeError): 3878 torch.repeat_interleave(y, torch.arange(9, device=device).reshape(3, 3), dim=0) 3879 3880 # test zero sized dimension 3881 x = torch.zeros((5, 0), device=device) 3882 y = torch.repeat_interleave(x, repeats=3, dim=1) 3883 self.assertEqual(y, x.new_zeros(5, 0, device=device)) 3884 3885 x = torch.tensor([], dtype=torch.int64, device=device) 3886 y = torch.repeat_interleave(x, x) 3887 self.assertEqual(y, x) 3888 3889 def test_repeat_interleave_simple(self): 3890 def helper(shape, dtype=torch.float32, num_repeats=torch.Tensor(), dim=None): 3891 x = torch.randn(shape, dtype=dtype, device="mps") 3892 x_cpu = x.detach().clone().cpu() 3893 3894 num_repeats_cpu = num_repeats.detach().clone().cpu() 3895 3896 repeats = torch.repeat_interleave(x, num_repeats, dim) 3897 repeats_cpu = torch.repeat_interleave(x_cpu, num_repeats_cpu, dim) 3898 3899 self.assertEqual(repeats, repeats_cpu) 3900 helper(shape=3, num_repeats=torch.tensor([100], device="mps")) 3901 helper(shape=(2, 2), num_repeats=torch.tensor([3, 3], device="mps"), dim=0) 3902 helper(shape=(10, 15, 8), num_repeats=torch.arange(10, device="mps"), dim=0) 3903 helper(shape=(10, 15, 8), num_repeats=torch.randint(0, 100, (15, ), device="mps"), dim=1) 3904 helper(shape=(10, 15, 30), num_repeats=torch.randint(0, 100, (30, ), device="mps"), dim=2) 3905 3906 def test_count_nonzero(self): 3907 def helper(dtype): 3908 n = [ 3909 [[1, 0, 2], [3, 0, 2], [7, 9, -4]], 3910 [[0, 2, 3], [3, 2, 1], [2, 0, 0]], 3911 ] 3912 cpu_x = torch.tensor(n, dtype=dtype) 3913 mps_x = torch.tensor(n, dtype=dtype).to('mps') 3914 3915 # All non-zeros 3916 self.assertEqual( 3917 torch.count_nonzero(cpu_x), 3918 torch.count_nonzero(mps_x) 3919 ) 3920 3921 # dim=1 3922 self.assertEqual( 3923 torch.count_nonzero(cpu_x, dim=1), 3924 torch.count_nonzero(mps_x, dim=1) 3925 ) 3926 3927 # dim=(0, 1) 3928 self.assertEqual( 3929 torch.count_nonzero(cpu_x, dim=(0, 1)), 3930 torch.count_nonzero(mps_x, dim=(0, 1)) 3931 ) 3932 helper(torch.int32) 3933 helper(torch.int64) 3934 helper(torch.float16) 3935 helper(torch.float32) 3936 3937 def _test_module_empty_input(self, module, inp, check_size=True): 3938 inp.requires_grad_(True) 3939 out = module(inp) 3940 gO = torch.rand_like(out) 3941 out.backward(gO) 3942 if check_size: 3943 
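# For an empty input, the module output is expected to keep the input's
# (empty) size, and the backward pass should leave all-zero gradients on
# both the parameters and the input.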
self.assertEqual(out.size(), inp.size()) 3944 for p in module.parameters(): 3945 if p.requires_grad: 3946 self.assertEqual(p.grad, torch.zeros_like(p.grad)) 3947 self.assertEqual(inp.grad, torch.zeros_like(inp)) 3948 3949 # Test dtype casting, with and without simultaneous device change 3950 def test_to(self): 3951 values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]] 3952 cpu_x = torch.tensor(values, device='cpu') 3953 mps_x = torch.tensor(values, device='mps') 3954 3955 self.assertEqual(cpu_x.int(), mps_x.int().cpu()) 3956 self.assertEqual(cpu_x.bool(), mps_x.bool().cpu()) 3957 self.assertEqual(cpu_x.float(), mps_x.float().cpu()) 3958 3959 self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(), 3960 torch.tensor(1, dtype=torch.int32)) 3961 self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False)) 3962 self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True)) 3963 self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(), 3964 torch.tensor(1, dtype=torch.int32)) 3965 self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(), 3966 torch.tensor(1.0)) 3967 self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int), 3968 torch.tensor(4, dtype=torch.int32)) 3969 self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(), 3970 torch.tensor(4, dtype=torch.int32)) 3971 self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int), 3972 torch.tensor(-8.34, device='cpu').to('mps').to(torch.int)) 3973 # Cast int8 and uint8 to float and compare results 3974 # See https://github.com/pytorch/pytorch/issues/80009 for more details 3975 cpu_byte = torch.tensor([60, 160, 20, 220], dtype=torch.uint8) 3976 cpu_char = torch.tensor([60, -60, 20, -120], dtype=torch.int8) 3977 for x_cpu in [cpu_byte, cpu_char]: 3978 x_mps = x_cpu.to('mps') 3979 self.assertEqual(x_mps.to(torch.float32), x_cpu.to(torch.float32)) 3980 3981 3982 def test_setitem_scalar(self) -> None: 3983 device = 'mps' 3984 for dtype in [torch.int32, torch.float32, torch.int64]: 3985 for i in range(3, 6): 3986 for j in range(3, 6): 3987 t = torch.zeros(i, j, dtype=dtype, device=device) 3988 self.assertEqual(t.sum(), 0) 3989 t[1, 1] = 1 3990 t[2, 1] = j 3991 t[1, 2] = i 3992 self.assertEqual(t[1, 1], 1) 3993 self.assertEqual(t[1, 2], i) 3994 self.assertEqual(t[2, 1], j) 3995 self.assertEqual(t.sum(), 1 + i + j) 3996 3997 def test_stride_of_strides(self) -> None: 3998 x = torch.rand(32, 1, device='mps') 3999 y = x.as_strided(size=(32, 2), stride=(1, 0)) 4000 # Casting stride of strided tensor to CPU used to crash with "buffer is not large enough."
assert 4001 # See https://github.com/pytorch/pytorch/issues/79181#issuecomment-1154683435 4002 z = y.as_strided(size=(32, 3), stride=(1, 0)).to("cpu") 4003 self.assertEqual(x.to("cpu").as_strided(size=(32, 3), stride=(1, 0)), z) 4004 4005 def test_type_casting(self): 4006 # https://github.com/pytorch/pytorch/issues/81567 4007 def helper(data, to_dtype): 4008 a_cpu = torch.tensor(data) 4009 a_mps = a_cpu.to(torch.device('mps')) 4010 4011 res_cpu = a_cpu.type(to_dtype) 4012 res_mps = a_mps.type(to_dtype) 4013 self.assertEqual(res_cpu, res_mps) 4014 4015 helper([9.0, 3.0, 5.0, 4.0], torch.LongTensor) 4016 helper([9.0, 3.0, 5.0, 4.0], torch.FloatTensor) 4017 helper([9.0, 3.0, 5.0, 4.0], torch.IntTensor) 4018 helper([9.0, 3.0, 5.0, 4.0], torch.ShortTensor) 4019 helper([9.0, 3.0, 5.0, 4.0], torch.HalfTensor) 4020 helper([9.0, 3.0, 5.0, 4.0], torch.CharTensor) 4021 helper([9.0, 3.0, 5.0, 4.0], torch.ByteTensor) 4022 4023 def test_to_casting(self): 4024 # https://github.com/pytorch/pytorch/issues/81567 4025 def helper(data, to_dtype): 4026 a_cpu = torch.tensor(data) 4027 a_mps = a_cpu.to(torch.device('mps')) 4028 4029 res_cpu = a_cpu.to(to_dtype) 4030 res_mps = a_mps.to(to_dtype) 4031 self.assertEqual(res_cpu, res_mps) 4032 4033 helper([9.0, 3.0, 5.0, 4.0], torch.int64) 4034 helper([9.0, 3.0, 5.0, 4.0], torch.float) 4035 helper([9.0, 3.0, 5.0, 4.0], torch.int32) 4036 helper([9.0, 3.0, 5.0, 4.0], torch.short) 4037 helper([9.0, 3.0, 5.0, 4.0], torch.half) 4038 helper([9.0, 3.0, 5.0, 4.0], torch.int8) 4039 helper([9.0, 3.0, 5.0, 4.0], torch.uint8) 4040 4041 def test_storage_offset_greater_than_src_nbytes(self): 4042 # https://github.com/pytorch/pytorch/issues/80844 4043 n_tensors = 100 4044 n_tensor_elems = 784 4045 elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32) 4046 4047 tensor_list = [] 4048 for i in range(0, n_tensors - 1): 4049 # create a list of contiguous view tensors (view tensor created by the slice op) 4050 t = elems[n_tensor_elems * i : n_tensor_elems * (i + 1)] 4051 tensor_list.append(t) 4052 4053 for i in range(0, n_tensors - 1): 4054 t = tensor_list[i].view(1, n_tensor_elems) 4055 t_mps = t.to("mps") 4056 self.assertEqual(t, t_mps.cpu(), f"i={i}") 4057 4058 # See https://github.com/pytorch/pytorch/issues/82427 4059 # and https://github.com/pytorch/pytorch/issues/83692 4060 def test_full_bugs(self): 4061 # Test should not crash 4062 x = torch.full((3, 3), True, device='mps') 4063 # torch.full should work for uint8 4064 y_mps = torch.full((2, 2), 247, device='mps', dtype=torch.uint8) 4065 y_cpu = torch.full((2, 2), 247, device='cpu', dtype=torch.uint8) 4066 self.assertEqual(y_mps, y_cpu) 4067 4068 @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12") 4069 # See https://github.com/pytorch/pytorch/issues/84995 4070 def test_div_bugs(self): 4071 for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']): 4072 if dtype != torch.int64: 4073 x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype) 4074 y = torch.div(x, 101, rounding_mode=mode) 4075 self.assertEqual(y.sum(), 0) 4076 4077 # See https://github.com/pytorch/pytorch/issues/82663 4078 def test_bool_expand(self): 4079 x = torch.tensor([[1], [0]], dtype=torch.bool, device='mps') 4080 y = torch.tensor([0, 1], dtype=torch.bool, device='mps') 4081 self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2))) 4082 4083 def test_int_expand(self): 4084 x = torch.tensor([[1], [0]], dtype=torch.int8, device='mps') 4085 y = torch.tensor([0, 1], dtype=torch.int8, device='mps') 4086 
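# A small hedged sanity sketch (not part of the original regression test,
# assuming standard broadcasting semantics): the (2, 1) column expands
# row-wise while the (2,) row expands column-wise, so the results differ.
expected_x = torch.tensor([[1, 1], [0, 0]], dtype=torch.int8, device='mps')
expected_y = torch.tensor([[0, 1], [0, 1]], dtype=torch.int8, device='mps')
self.assertTrue(torch.equal(x.expand(2, 2), expected_x))
self.assertTrue(torch.equal(y.expand(2, 2), expected_y))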
self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2))) 4087 4088 # Empty unary op should return tensor of the same size 4089 def test_empty_neg(self): 4090 x = torch.tensor([[]], device='mps') 4091 y = -x 4092 self.assertEqual(x, y) 4093 4094 def _test_unique_scalar_empty(self, dtype, device, f): 4095 # test scalar 4096 x = torch.tensor(0, dtype=dtype, device=device) 4097 unique, inverse, counts = f(x, return_inverse=True, return_counts=True) 4098 expected_unique = torch.tensor([0], dtype=dtype, device=device) 4099 expected_inverse = torch.tensor(0, device=device) 4100 expected_counts = torch.tensor([1], device=device) 4101 self.assertEqual(unique, expected_unique) 4102 self.assertEqual(inverse, expected_inverse) 4103 self.assertEqual(counts, expected_counts) 4104 4105 # test zero sized tensor 4106 x = torch.zeros((0, 0, 3), dtype=dtype, device=device) 4107 unique, inverse, counts = f(x, return_inverse=True, return_counts=True) 4108 expected_unique = torch.tensor([], dtype=dtype, device=device) 4109 expected_inverse = torch.empty((0, 0, 3), dtype=torch.long, device=device) 4110 expected_counts = torch.tensor([], dtype=torch.long, device=device) 4111 self.assertEqual(unique, expected_unique) 4112 self.assertEqual(inverse, expected_inverse) 4113 self.assertEqual(counts, expected_counts) 4114 4115 def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape): 4116 def ensure_tuple(x): 4117 if isinstance(x, torch.Tensor): 4118 return (x,) 4119 return x 4120 4121 for return_inverse in [True, False]: 4122 for return_counts in [True, False]: 4123 # test with expected 4124 ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) 4125 self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) 4126 self.assertEqual(expected_unique, ret[0]) 4127 if return_inverse: 4128 self.assertEqual(expected_inverse, ret[1]) 4129 if return_counts: 4130 count_index = 1 + int(return_inverse) 4131 self.assertEqual(expected_counts, ret[count_index]) 4132 4133 # tests per-element unique on a higher rank tensor. 
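# Without a dim argument torch.unique flattens its input, so viewing the
# same data with an additional shape must leave the unique values and
# counts unchanged; only the inverse indices pick up the new shape.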
4134 y = x.view(additional_shape) 4135 y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True) 4136 self.assertEqual(expected_unique, y_unique) 4137 self.assertEqual(expected_inverse.view(additional_shape), y_inverse) 4138 self.assertEqual(expected_counts, y_counts) 4139 4140 def test_unique_all_dtypes(self, device="mps"): 4141 def helper(dtype): 4142 def ensure_tuple(x): 4143 if isinstance(x, torch.Tensor): 4144 return (x,) 4145 return x 4146 4147 if dtype is torch.bool: 4148 x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device) 4149 expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device) 4150 expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device) 4151 expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device) 4152 else: 4153 x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device) 4154 expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device) 4155 expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device) 4156 expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device) 4157 4158 # test sorted unique 4159 fs = ( 4160 lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs), 4161 lambda x, **kwargs: x.unique(sorted=True, **kwargs), 4162 ) 4163 x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x) 4164 xs = (x, x_sliced) 4165 for f, x in product(fs, xs): 4166 self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2)) 4167 self._test_unique_scalar_empty(dtype, device, f) 4168 4169 # test unsorted unique 4170 fs = ( 4171 lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs), 4172 lambda x, **kwargs: x.unique(sorted=False, **kwargs) 4173 ) 4174 for f, x in product(fs, xs): 4175 self._test_unique_scalar_empty(dtype, device, f) 4176 for return_inverse, return_counts in product((True, False), repeat=2): 4177 ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts)) 4178 self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts)) 4179 x_list = x.tolist() 4180 x_unique_list = ret[0].tolist() 4181 self.assertEqual(expected_unique.tolist(), sorted(x_unique_list)) 4182 if return_inverse: 4183 x_inverse_list = ret[1].tolist() 4184 for i, j in enumerate(x_inverse_list): 4185 self.assertEqual(x_list[i], x_unique_list[j]) 4186 if return_counts: 4187 count_index = 1 + int(return_inverse) 4188 x_counts_list = ret[count_index].tolist() 4189 for i, j in zip(x_unique_list, x_counts_list): 4190 count = 0 4191 for k in x_list: 4192 if k == i: 4193 count += 1 4194 self.assertEqual(j, count) 4195 [helper(dtype) for dtype in [torch.float32, torch.int64, torch.int32, torch.int16, torch.uint8]] 4196 4197 def test_unique(self): 4198 def helper(x, return_inverse, return_counts): 4199 cpu_x = x 4200 x = cpu_x.detach().clone().to('mps') 4201 4202 result = torch.unique(x, return_inverse=return_inverse, return_counts=return_counts) 4203 result_cpu = torch.unique(cpu_x, return_inverse=return_inverse, return_counts=return_counts) 4204 4205 self.assertEqual(result, result_cpu) 4206 helper(torch.tensor([1, 2, 4, 2, 1]), False, False) 4207 helper(torch.randint(3, (10, )), False, False) 4208 helper(torch.randint(3, (10, )), True, False) 4209 helper(torch.randint(3, (10, )), False, True) 4210 helper(torch.randint(3, (10, )), True, True) 4211 helper(torch.randint(3, (1, )), True, True) 4212 
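# Single-element and empty inputs exercise the scalar and zero-size paths.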
helper(torch.randint(3, (0, )), True, True) 4213 # Regression test for https://github.com/pytorch/pytorch/issues/104879 4214 x = torch.arange(2, device="mps") 4215 self.assertEqual(x.reshape(1, 1, 2).unique(), x) 4216 4217 def test_unique_consecutive(self): 4218 def helper(x, dim, return_inverse, return_counts): 4219 cpu_x = x 4220 x = cpu_x.detach().clone().to('mps') 4221 4222 result = torch.unique_consecutive(x, dim=dim, return_inverse=return_inverse, return_counts=return_counts) 4223 result_cpu = torch.unique_consecutive(cpu_x, dim=dim, return_inverse=return_inverse, return_counts=return_counts) 4224 4225 self.assertEqual(result, result_cpu) 4226 helper(torch.tensor([1, 2, 4, 2, 1]), 0, False, False) 4227 helper(torch.randint(3, (10, )), 0, False, False) 4228 helper(torch.randint(3, (10, )), 0, True, False) 4229 helper(torch.randint(3, (10, )), 0, False, True) 4230 helper(torch.randint(3, (10, )), 0, True, True) 4231 helper(torch.randint(3, (10, )), 0, True, True) 4232 helper(torch.randint(3, (1, )), 0, True, True) 4233 helper(torch.randint(3, (0, )), 0, True, True) 4234 4235 helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, False, False) 4236 helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, True, True) 4237 helper(torch.randint(2, (20, 2)), 0, True, True) 4238 helper(torch.randint(2, (1, 2)), 0, True, True) 4239 helper(torch.randint(2, (0, 2)), 0, True, True) 4240 4241 helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, False, False) 4242 helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, True, True) 4243 helper(torch.randint(2, (2, 20)), 1, True, True) 4244 helper(torch.randint(2, (2, 1)), 1, True, True) 4245 helper(torch.randint(2, (2, 0)), 1, True, True) 4246 4247 # See https://github.com/pytorch/pytorch/issues/85675 4248 def test_cat_non_contiguous(self): 4249 def rotate_subset(data, dim): 4250 x1 = data[:, :, :2, :] 4251 x2 = data[:, :, 2:, :] 4252 self.assertFalse(x1.is_contiguous()) 4253 self.assertFalse(x2.is_contiguous()) 4254 return torch.concat((x1, x2), dim=dim) 4255 for dtype in MPS_DTYPES: 4256 if dtype == torch.bool: 4257 continue 4258 data = torch.arange(48, dtype=dtype).reshape(1, 2, 4, 6) 4259 data = data.to(memory_format=torch.channels_last) 4260 mps_data = data.to("mps") 4261 self.assertEqual(data, mps_data) 4262 for dim in range(data.dim()): 4263 cpu_result = rotate_subset(data, dim) 4264 mps_result = rotate_subset(mps_data, dim) 4265 self.assertEqual(cpu_result, mps_result.to("cpu")) 4266 # TODO: enable memory format test 4267 # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous()) 4268 4269 # See https://github.com/pytorch/pytorch/issues/85967 4270 def test_from_numpy_non_contiguous(self): 4271 a = np.arange(9).reshape(3, 3)[:, :2] 4272 t_cpu = torch.tensor(a, device="cpu") 4273 t_mps = torch.tensor(a, device="mps") 4274 self.assertEqual(t_cpu, t_mps.to("cpu")) 4275 4276 # See https://github.com/pytorch/pytorch/issues/86954 4277 def test_copy_non_contiguous(self): 4278 x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1) 4279 self.assertFalse(x.is_contiguous()) 4280 y = x.to('mps') 4281 self.assertFalse(y.is_contiguous()) 4282 self.assertEqual(x, y.to('cpu')) 4283 4284 x = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2] 4285 y = x.to('mps') 4286 self.assertEqual(x, y.to('cpu')) 4287 4288 x = torch.full((4, 4, 4, 4), 13, device="cpu") 4289 y = torch.full((4, 4, 4, 4), 13, device="mps") 4290 z = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2] 4291 
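# Both sides of the assignments below are permuted, strided views, so this
# exercises a copy where destination and source are each non-contiguous.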
x.permute(3, 2, 1, 0)[1::, ::2] = z 4292 # As y is on MPS and z on CPU, this dispatches to a copy operator 4293 y.permute(3, 2, 1, 0)[1::, ::2] = z 4294 self.assertEqual(x, y.to('cpu')) 4295 4296 # See https://github.com/pytorch/pytorch/issues/95417 4297 def test_copy_storage_offset(self): 4298 x_cpu = torch.zeros(5, device="cpu", dtype=torch.float32) 4299 x_mps = torch.zeros(5, device="mps", dtype=torch.float32) 4300 update_cpu = torch.tensor([1, 1], device="cpu", dtype=torch.int64) 4301 update_mps = torch.tensor([1, 1], device="mps", dtype=torch.int64) 4302 x_cpu[2:4] = update_cpu 4303 x_mps[2:4] = update_mps # implicit type casting and copy 4304 self.assertEqual(x_cpu, x_mps) 4305 4306 x_cpu[2:4] = update_mps # implicit device moving and copy 4307 self.assertEqual(x_cpu, x_mps) 4308 4309 def test_copy_broadcasting(self): 4310 def helper(src_shape, dst_shape, src_dtype, dst_dtype): 4311 cpu_src = torch.randint(0, 127, src_shape).to(src_dtype) 4312 cpu_dst = torch.randint(0, 127, dst_shape).to(dst_dtype) 4313 cpu_result = cpu_dst.copy_(cpu_src) 4314 mps_src = cpu_src.to("mps") 4315 mps_dst = cpu_dst.to("mps") 4316 mps_result = mps_dst.copy_(mps_src) 4317 self.assertEqual(cpu_result, mps_result) 4318 4319 test_dtypes = [torch.float32, torch.int32, torch.int16, torch.int8] 4320 4321 for (src_dtype, dst_dtype) in itertools.product(test_dtypes, test_dtypes): 4322 helper((2, 1), (2, 3), src_dtype, dst_dtype) 4323 helper((2, 1), (2, 2), src_dtype, dst_dtype) 4324 helper((3, 1, 4, 1), (3, 4, 4, 5), src_dtype, dst_dtype) 4325 helper((3,), (2, 3), src_dtype, dst_dtype) 4326 helper((2,), (2, 2), src_dtype, dst_dtype) 4327 helper((4, 1, 5), (3, 4, 4, 5), src_dtype, dst_dtype) 4328 helper((4, 1, 5), (4, 0, 5), src_dtype, dst_dtype) 4329 helper((1, 5), (4, 0, 5), src_dtype, dst_dtype) 4330 helper((3, 1, 0), (3, 5, 0), src_dtype, dst_dtype) 4331 helper((0, 1, 0), (0, 5, 0), src_dtype, dst_dtype) 4332 # Regression test for https://github.com/pytorch/pytorch/issues/107867 4333 self.assertEqual(torch.tensor([[1]], device='mps').item(), 1.0) 4334 4335 # See https://github.com/pytorch/pytorch/pull/84742 4336 # and https://github.com/pytorch/pytorch/pull/78319 4337 def test_binops_dtype_precedence(self): 4338 # Test dtype precedence (casting order) in binary operations by comparing to CPU result 4339 # Example values for all dtypes supported on the MPS backend 4340 sample_vals = { 4341 torch.bool: [False, True], 4342 torch.int16: [-15, 0, 1, 10], 4343 torch.int32: [-376, 0, 1, 13], 4344 torch.int64: [-8, 0, 1, 77], 4345 torch.float16: [-234.5, 0.0, 1.0, 2.0], 4346 torch.float32: [-1.0, 0.0, 0.1, 111.99], 4347 } 4348 # Test all combinations of dtypes, operations, dimensionality 4349 for dtype1, dtype2, binop in itertools.product( 4350 sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul', 'div']): 4351 # bool minus bool is generally unsupported, so skip 4352 if binop == 'sub' and (dtype1 == torch.bool or dtype2 == torch.bool): 4353 continue 4354 full_shape = (10,) 4355 for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]): 4356 # print(f'{dtype1},{dtype2}: ({val1}).{binop}({val2})') 4357 # print(getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop) 4358 # (torch.tensor(val2, dtype=dtype2, device='mps'))) 4359 # print(getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop) 4360 # (torch.tensor(val2, dtype=dtype2, device='cpu'))) 4361 self.assertEqual( 4362 getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop) 4363 (torch.tensor(val2, dtype=dtype2, 
device='mps')), 4364 getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop) 4365 (torch.tensor(val2, dtype=dtype2, device='cpu'))) 4366 self.assertEqual( 4367 getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop) 4368 (torch.tensor([val2], dtype=dtype2, device='mps')), 4369 getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop) 4370 (torch.tensor([val2], dtype=dtype2, device='cpu'))) 4371 self.assertEqual( 4372 getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop) 4373 (torch.tensor([val2], dtype=dtype2, device='mps')), 4374 getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop) 4375 (torch.tensor([val2], dtype=dtype2, device='cpu'))) 4376 self.assertEqual( 4377 getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop) 4378 (torch.tensor(val2, dtype=dtype2, device='mps')), 4379 getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop) 4380 (torch.tensor(val2, dtype=dtype2, device='cpu'))) 4381 # Test tensors created with torch.full 4382 x1 = torch.full(full_shape, val1, dtype=dtype1, device='mps') 4383 y1 = torch.tensor(val2, dtype=dtype2, device='mps') 4384 x2 = torch.full(full_shape, val1, dtype=dtype1, device='cpu') 4385 y2 = torch.tensor(val2, dtype=dtype2, device='cpu') 4386 self.assertEqual(getattr(x1, binop)(y1), getattr(x2, binop)(y2)) 4387 x3 = torch.tensor(val1, dtype=dtype1, device='mps') 4388 y3 = torch.full(full_shape, val2, dtype=dtype2, device='mps') 4389 x4 = torch.tensor(val1, dtype=dtype1, device='cpu') 4390 y4 = torch.full(full_shape, val2, dtype=dtype2, device='cpu') 4391 self.assertEqual(getattr(x3, binop)(y3), getattr(x4, binop)(y4)) 4392 self.assertEqual( 4393 getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop) 4394 (torch.full(full_shape, val2, dtype=dtype2, device='mps')), 4395 getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop) 4396 (torch.full(full_shape, val2, dtype=dtype2, device='cpu'))) 4397 4398 def test_nansum(self): 4399 def helper(dtype, noncontiguous, dim): 4400 zero_cpu = torch.zeros((), dtype=dtype) 4401 4402 # Randomly scale the values 4403 scale = random.randint(10, 100) 4404 x_cpu: torch.Tensor = make_tensor( 4405 (5, 5), dtype=dtype, device='cpu', 4406 low=-scale, high=scale, noncontiguous=noncontiguous) 4407 4408 if dtype.is_floating_point: 4409 nan_mask_cpu = x_cpu < (0.2 * scale) 4410 x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu) 4411 x_cpu[nan_mask_cpu] = np.nan 4412 else: 4413 x_no_nan_cpu = x_cpu 4414 4415 x_mps = x_cpu.to('mps') 4416 actual_out_mps = torch.empty(0, dtype=dtype, device='mps') 4417 expect_out_cpu = torch.empty(0, dtype=dtype) 4418 dim_kwargs = {"dim": dim} if dim is not None else {} 4419 expect = torch.sum(x_no_nan_cpu, **dim_kwargs) 4420 4421 actual_cpu = torch.nansum(x_cpu, **dim_kwargs) 4422 # Sanity check on CPU 4423 self.assertEqual(expect, actual_cpu) 4424 4425 # Test MPS 4426 actual_mps = torch.nansum(x_mps, **dim_kwargs) 4427 # Test out= variant 4428 torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs) 4429 torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs) 4430 self.assertEqual(expect, actual_mps) 4431 self.assertEqual(expect_out_cpu, actual_out_mps) 4432 4433 args = itertools.product( 4434 (torch.float16, torch.float32, torch.int32, torch.int64), # dtype 4435 (True, False), # noncontiguous 4436 (0, 1, None), # dim 4437 ) 4438 4439 for dtype, noncontiguous, dim in args: 4440 with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim): 4441 helper(dtype, noncontiguous, dim) 4442 4443 def 
test_cumsum_all_dtypes(self): 4444 def helper(dtype): 4445 t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype) 4446 t_cpu = torch.tensor([1, 1, 1, 1], device="cpu") 4447 4448 a = t.cumsum(0, dtype=dtype) 4449 a_cpu = t_cpu.cumsum(0, dtype=dtype) 4450 4451 self.assertEqual(a.cpu(), a_cpu) 4452 [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]] 4453 4454 try: 4455 helper(torch.int64) 4456 except Exception as e: 4457 e_string = str(e) 4458 self.assertEqual(e_string, "MPS does not support cumsum_out_mps op with int64 input." + 4459 " Support has been added in macOS 13.3") 4460 4461 def test_cumsum_bool(self): 4462 a = torch.ones(2**16, dtype=torch.bool) 4463 t_cpu = a.cumsum(0) 4464 t_mps = a.to("mps").cumsum(0) 4465 4466 self.assertEqual(t_cpu, t_mps) 4467 4468 def test_cumsum_minus_one_axis(self): 4469 def helper(dtype): 4470 # Test with axis -1 4471 cpu_x = None 4472 if dtype == torch.float32: 4473 cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32) 4474 else: 4475 cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32) 4476 x = cpu_x.detach().clone().to('mps') 4477 4478 cpu_y = cpu_x.cumsum(-1) 4479 y = x.cumsum(-1) 4480 4481 self.assertEqual(y, cpu_y) 4482 4483 [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]] 4484 4485 def test_cumprod_all_dtypes(self): 4486 def helper(dtype): 4487 t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype) 4488 t_cpu = torch.tensor([1, 1, 1, 1], device="cpu") 4489 4490 a = t.cumprod(0, dtype=dtype) 4491 a_cpu = t_cpu.cumprod(0, dtype=dtype) 4492 4493 self.assertEqual(a.cpu(), a_cpu) 4494 [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]] 4495 4496 try: 4497 helper(torch.int64) 4498 except Exception as e: 4499 e_string = str(e) 4500 self.assertEqual(e_string, "MPS does not support cumprod_out_mps op with int64 input." 
4501 + " Support has been added in macOS 13.3") 4502 4503 def test_cumprod_minus_one_axis(self): 4504 def helper(dtype): 4505 # Test with axis -1 4506 cpu_x = None 4507 if dtype == torch.float32: 4508 cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32) 4509 else: 4510 cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32) 4511 x = cpu_x.detach().clone().to('mps') 4512 4513 cpu_y = cpu_x.cumprod(-1) 4514 y = x.cumprod(-1) 4515 4516 self.assertEqual(y, cpu_y) 4517 4518 [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]] 4519 4520 def test_median_int16(self): 4521 def helper(shape, dtype): 4522 cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype) 4523 x = cpu_x.detach().clone().to('mps') 4524 4525 median_result = torch.median(x) 4526 median_result_cpu = torch.median(cpu_x) 4527 self.assertEqual(median_result, median_result_cpu) 4528 4529 helper((2, 8, 4, 5), torch.int16) 4530 4531 def test_activation_checkpoint_does_not_error(self): 4532 from torch.utils.checkpoint import checkpoint 4533 4534 for use_reentrant in (True, False): 4535 a = torch.tensor(1., device="mps", requires_grad=True) 4536 4537 def fn(x): 4538 return x.sin().cos().exp() 4539 4540 out = checkpoint(fn, a, use_reentrant=use_reentrant) 4541 out.backward() 4542 4543 def test_as_strided(self): 4544 values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] 4545 values_1 = [[1.0, 1.0], [1.0, 1.0]] 4546 cpu_x = torch.tensor(values, device='cpu') 4547 ones1 = torch.tensor(values_1, device='mps') 4548 x = cpu_x.detach().clone().to('mps').requires_grad_() 4549 strided_cpu = torch.as_strided(cpu_x, (2, 2), (1, 2)) 4550 strided_mps = torch.as_strided(x, (2, 2), (1, 2)) 4551 self.assertEqual(strided_mps, strided_cpu) 4552 strided_cpu_out = strided_cpu + ones1.to('cpu') 4553 strided_mps_out = strided_mps + ones1 4554 self.assertEqual(strided_cpu_out, strided_mps_out) 4555 4556 # test with storage offsets 4557 cpu_x = torch.rand(3, 3, device='cpu') 4558 mps_x = cpu_x.to('mps') 4559 strided_cpu1 = torch.as_strided(cpu_x, (2, 2), (1, 2), 0) 4560 strided_mps1 = torch.as_strided(mps_x, (2, 2), (1, 2), 0) 4561 strided_cpu2 = torch.as_strided(cpu_x, (2, 2), (1, 2), 1) 4562 strided_mps2 = torch.as_strided(mps_x, (2, 2), (1, 2), 1) 4563 strided_cpu_out = strided_cpu1 - strided_cpu2 4564 strided_mps_out = strided_mps1 - strided_mps2 4565 self.assertEqual(strided_cpu_out, strided_mps_out) 4566 4567 def test_unfold(self): 4568 x = torch.arange(1., 8) 4569 x_mps = torch.arange(1., 8, device="mps") 4570 4571 y = x.unfold(0, 2, 1) 4572 y_mps = x_mps.unfold(0, 2, 1) 4573 4574 self.assertEqual(y, y_mps) 4575 4576 def test_unfold_all_devices_and_dtypes(self): 4577 supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8] 4578 for dt in supported_dtypes: 4579 x = torch.empty((0, 1, 3, 0), dtype=dt, device="mps") 4580 self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape) 4581 4582 def test_unfold_scalars(self): 4583 x = torch.tensor(0.5, device="mps") 4584 # unfold on a 0-dimensional tensor should always return a 1-d dimensional 4585 # tensor of shape [size] (i.e., the second parameter to unfold) 4586 4587 self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1)) 4588 self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2)) 4589 self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1)) 4590 4591 def test_bincount_simple(self): 4592 input = torch.randint(0, 8, (5,), dtype=torch.int32, device="mps") 4593 
input_cpu = input.to("cpu") 4594 weights = torch.linspace(0, 1, steps=5, device="mps", dtype=torch.float32) 4595 weights_cpu = weights.to("cpu") 4596 4597 x = torch.bincount(input) 4598 x_cpu = torch.bincount(input_cpu) 4599 self.assertEqual(x, x_cpu) 4600 4601 y = input.bincount(weights) 4602 y_cpu = input_cpu.bincount(weights_cpu) 4603 self.assertEqual(y, y_cpu) 4604 4605 def test_bincount_reduction(self): 4606 device = "mps" 4607 # negative input throws 4608 with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): 4609 torch.bincount(torch.tensor([1, -1], device=device, dtype=torch.int32)) 4610 # n-d input, with n > 1 throws 4611 with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): 4612 torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device)) 4613 # minlength < 0 throws 4614 with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'): 4615 torch.bincount(torch.tensor([1, 3], device=device), 4616 torch.tensor([.2, .2], device=device), 4617 minlength=-1) 4618 # n-d weights, with n > 1 throws 4619 with self.assertRaisesRegex(RuntimeError, '1-d'): 4620 torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32), 4621 torch.tensor([[1., 0.3], [1., 0.3]], device=device, dtype=torch.float)) 4622 # input and weights dim mismatch 4623 with self.assertRaisesRegex(RuntimeError, 'same length'): 4624 torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32), 4625 torch.tensor([1., 0.3, 0.5], device=device, dtype=torch.float)) 4626 # 1-d input with no elements and default minlength 4627 self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)), 4628 torch.zeros(0, dtype=torch.long, device=device)) 4629 # 1-d input with no elements and specified minlength 4630 self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10), 4631 torch.zeros(10, dtype=torch.long, device=device)) 4632 4633 # test tensor method without weights 4634 long_counts = torch.tensor( 4635 [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount() 4636 self.assertEqual( 4637 torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device), 4638 long_counts) 4639 # test avoiding overflow for uint8 (#76979) 4640 count_uint8 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.uint8, device=device).bincount() 4641 count_int16 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.int16, device=device).bincount() 4642 self.assertEqual(count_uint8, count_int16) 4643 # test minlength functionality 4644 int_counts = torch.bincount( 4645 torch.tensor([1, 1, 1, 1], device=device, dtype=torch.int32), minlength=5) 4646 self.assertEqual( 4647 torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device), 4648 int_counts) 4649 # test weights 4650 byte_counts = torch.bincount( 4651 torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32), 4652 torch.tensor([.1, .2, .3, .4, .5], device=device)) 4653 self.assertEqual( 4654 torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts) 4655 byte_counts = torch.bincount( 4656 torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32), 4657 torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device)) 4658 self.assertEqual( 4659 torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.int32), byte_counts) 4660 # test non-contiguous inputs and weights 4661 inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device, dtype=torch.int32) 4662 weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device) 4663 for i in [0, 1]: 4664 
assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous" 4665 assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous" 4666 # inputs are non-contiguous but weights are contiguous 4667 self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2])) 4668 # inputs and weights are non-contiguous 4669 self.assertEqual( 4670 inputs[:, 1].bincount(weights[:, 1]), 4671 torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32)) 4672 # weights are non-contiguous but inputs are contiguous 4673 self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]), 4674 torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32)) 4675 4676 # test bincount on non-contiguous slices 4677 all0s = torch.zeros((32, 2), dtype=torch.int32, device=device) 4678 self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32])) 4679 4680 all1s = torch.ones((32, 2), dtype=torch.int32, device=device) 4681 self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32])) 4682 4683 # test large number of bins - global memory use 4684 big_exp = torch.zeros(100, device=device) 4685 big_exp[-1] = 50.0 4686 big_w = torch.tensor([.5] * 100, device=device) 4687 big_out = torch.tensor([99] * 100, device=device, dtype=torch.int32).bincount(big_w) 4688 self.assertEqual(big_exp, big_out) 4689 # test large input size 4690 big_exp = torch.zeros(2, device=device, dtype=torch.int64) 4691 big_exp[1] = 10 4692 big_out = torch.ones(10, dtype=torch.int8, device=device).bincount() 4693 self.assertEqual(big_exp, big_out) 4694 4695 def test_bincount(self): 4696 device = "mps" 4697 input_size = (5000,) 4698 w = torch.randn(input_size, dtype=torch.float, device=device) 4699 w_cpu = w.cpu() 4700 4701 t = torch.randint(50, input_size, dtype=torch.int8, device=device) 4702 self.assertEqual(t.cpu().bincount(), t.bincount()) 4703 self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w)) 4704 4705 t = torch.randint(500, input_size, dtype=torch.int32, device=device) 4706 self.assertEqual(t.cpu().bincount(), t.bincount()) 4707 self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w)) 4708 4709 t = torch.randint(2000, input_size, dtype=torch.int32, device=device) 4710 self.assertEqual(t.cpu().bincount(), t.bincount()) 4711 self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w)) 4712 4713 t = torch.zeros([10], dtype=torch.int32, device=device) 4714 t[0] = 35488 4715 counted = t.bincount(minlength=65536) 4716 self.assertEqual(torch.sum(counted), 10) 4717 4718 def test_sum_backward(self): 4719 def helper(n, c): 4720 values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] 4721 cpu_x = torch.tensor(values, device='cpu', requires_grad=True) 4722 x = cpu_x.detach().clone().to('mps').requires_grad_() 4723 4724 all_sum = torch.sum(x) 4725 all_sum_cpu = torch.sum(cpu_x) 4726 4727 all_sum.backward() 4728 all_sum_cpu.backward() 4729 self.assertEqual(all_sum, all_sum_cpu) 4730 self.assertEqual(x.grad, cpu_x.grad) 4731 4732 helper(3, 3) 4733 4734 # L1 loss 4735 def test_l1_loss(self): 4736 def helper(shape, reduction): 4737 # create the criterion 4738 loss = torch.nn.L1Loss(reduction=reduction) 4739 4740 inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 4741 targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 4742 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_() 4743 targetMPS = targetCPU.detach().clone().to('mps') 4744 4745 # forward pass 4746 outputCPU = loss(inputCPU, targetCPU) 4747 outputMPS = loss(inputMPS, targetMPS) 4748 

    def test_sum_backward(self):
        def helper(n, c):
            values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
            cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            all_sum = torch.sum(x)
            all_sum_cpu = torch.sum(cpu_x)

            all_sum.backward()
            all_sum_cpu.backward()
            self.assertEqual(all_sum, all_sum_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper(3, 3)

    # L1 loss
    def test_l1_loss(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.L1Loss(reduction=reduction)

            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # chose 2 just to make the grad_output > 1 in backward pass
                outputCPU.backward(gradient=torch.full_like(outputCPU, 2))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 2))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify if changes in shape would cause cached graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')

    # Mean Squared Error
    def test_mse_loss(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.MSELoss(reduction=reduction)

            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # chose 2 just to make the grad_output > 1 in backward pass
                outputCPU.backward(gradient=torch.full_like(outputCPU, 2))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 2))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify if changes in shape would cause cached graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')
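
    # Note on the `gradient=` seeds used above (general autograd reasoning, not
    # MPS-specific): `out.backward(gradient=g)` computes the vector-Jacobian
    # product J^T g, so seeding with 2 instead of the default ones scales every
    # input gradient by 2. Comparing CPU and MPS grads under a non-unit seed
    # also catches backends that ignore the incoming grad_output scaling.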

    def test_mse_loss_strided_output(self):
        # https://github.com/pytorch/pytorch/issues/124621
        lf = nn.MSELoss(reduction='none')
        model_cpu = nn.Sequential(
            nn.Conv1d(3, 3, 1),
        )
        model_mps = copy.deepcopy(model_cpu).to("mps")

        x = torch.randn(128, 10, 3)
        x = x.permute(0, 2, 1)

        x_mps = x.detach().clone().to("mps").permute(0, 2, 1)
        x_mps = x_mps.permute(0, 2, 1)

        y = model_cpu(x)
        y_mps = model_mps(x_mps)

        y = y.permute(0, 2, 1)[:, :5, :]
        y_mps = y_mps.permute(0, 2, 1)[:, :5, :]

        y_hat = torch.randn(128, 5, 3)
        y_hat_mps = y_hat.detach().clone().to("mps")

        loss = lf(y, y_hat)
        loss_mps = lf(y_mps, y_hat_mps)
        self.assertEqual(loss, loss_mps)

    # Binary Cross Entropy
    def test_bce_loss_simple(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.BCELoss(reduction=reduction)

            # input and target must be within [0..1]
            input_t = np.random.random_sample(size=shape).astype(np.float32)
            target_t = np.random.random_sample(size=shape).astype(np.float32)
            inputCPU = torch.tensor(input_t, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.tensor(target_t, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # chose 0.6 just to have the grad_output != 1
                outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify if changes in shape would cause cached graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')
        helper([1, 1, 32, 32], 'mean')

    def test_bce_loss_always_nonnegative(self):
        target = torch.ones(5, device='mps')
        input = torch.ones(5, device='mps')
        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

        target = torch.zeros(5, device='mps')
        input = torch.zeros(5, device='mps')
        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

    def test_bce_loss_size_mismatch(self):
        bceloss = nn.BCELoss()
        a = torch.rand(25, device='mps')
        b = torch.rand(25, 1, device='mps')
        with self.assertRaisesRegex(ValueError, r'Using a target size \('):
            bceloss(a, b)

    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
        x_size = 1024
        y_size = 256
        target = torch.rand(x_size, y_size, device='mps')

        for reduction in ['none', 'mean', 'sum']:
            output_sig = torch.rand(x_size, y_size, device='mps') - 0.5
            output_logits = output_sig.clone().detach()

            output_sig.requires_grad = True
            output_logits.requires_grad = True
            weight = torch.rand(y_size, device='mps')

            loss_sig = nn.BCELoss(weight, reduction=reduction)(
                torch.sigmoid(output_sig), target
            )
            loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
                output_logits, target
            )

            self.assertEqual(loss_logits, loss_sig)

            if reduction == 'none':
                grad = torch.rand(x_size, y_size, device='mps')
                loss_sig.backward(grad)
                loss_logits.backward(grad)
            else:
                loss_sig.backward()
                loss_logits.backward()

            self.assertEqual(output_sig.grad, output_logits.grad)
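
    # BCEWithLogitsLoss is mathematically sigmoid followed by BCELoss, but it is
    # evaluated in the numerically stable log-sum-exp form
    #     l(x, z) = max(x, 0) - x * z + log(1 + exp(-|x|)),
    # which is why the stability test further below stays finite at x = -120,
    # where a naive log(sigmoid(x)) would underflow to log(0).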

    def test_bce_with_logits_has_correct_grad_at_zero(self):
        output = torch.zeros(3, 1, requires_grad=True, device='mps')
        target = torch.zeros(3, 1, device='mps')
        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
        self.assertEqual(output.grad, expected_grad)

    def test_bce_with_logits_broadcasts_weights(self):
        target = torch.rand(16, 4, device='mps')
        output = torch.rand(16, 4, device='mps') - 0.5

        weight = torch.rand(4, device='mps')
        out1 = nn.BCEWithLogitsLoss(weight)(output, target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCEWithLogitsLoss(weight)(output, target)

        self.assertEqual(out1, out2)

        weight = torch.rand(16, 1, device='mps')
        out1 = nn.BCEWithLogitsLoss(weight)(output, target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCEWithLogitsLoss(weight)(output, target)

        self.assertEqual(out1, out2)

    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
        target = torch.rand(64, 4, device='mps')
        output = torch.rand(64, 4, device='mps') - 0.5
        pos_weight = torch.ones(64, 4, device='mps')

        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))

    def test_bce_with_logits_broadcasts_pos_weights(self):
        target = torch.rand(64, 4, device='mps')
        output = torch.rand(64, 4, device='mps') - 0.5
        pos_weight = torch.rand(4, device='mps')
        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)

        pos_weight1 = pos_weight.expand(1, 4)
        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)

        pos_weight2 = pos_weight.expand(64, 4)
        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)

        self.assertEqual(out1, out2)
        self.assertEqual(out1, out3)

    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
        output = torch.zeros(3, 1, requires_grad=True, device='mps')
        target = torch.zeros(3, 1, device='mps')
        pos_weight = torch.ones(3, 1, device='mps')
        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
        grad = output.grad
        self.assertEqual(grad, expected_grad)

    def test_bce_with_logits_stability(self):
        output = torch.tensor([0., -120.], device='mps')
        target = torch.tensor([0., 1.], device='mps')
        pos_weight = torch.tensor([1., 1.], device='mps')

        out1 = nn.BCEWithLogitsLoss()(output, target)
        self.assertTrue(torch.isfinite(out1).all().item())

        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
        self.assertTrue(torch.isfinite(out2).all().item())

    def test_bce_loss_broadcasts_weights(self):
        sigmoid = nn.Sigmoid()
        target = torch.rand(16, 4, device='mps')
        output = torch.rand(16, 4, device='mps') - 0.5

        weight = torch.rand(4, device='mps')
        out1 = nn.BCELoss(weight)(sigmoid(output), target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCELoss(weight)(sigmoid(output), target)

        self.assertEqual(out1, out2)

        weight = torch.rand(16, 1, device='mps')
        out1 = nn.BCELoss(weight)(sigmoid(output), target)

        weight = weight.expand(16, 4).contiguous()
        out2 = nn.BCELoss(weight)(sigmoid(output), target)

        self.assertEqual(out1, out2)

    def test_cross_entropy_loss(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/116095
        loss = nn.CrossEntropyLoss()
        pred = torch.randn(3, 5, requires_grad=True, dtype=torch.float16, device='mps')
        target = torch.ones(3, dtype=torch.long, device='mps')
        output = loss(pred, target)
        output.backward()

    def test_log_softmax(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
        mps_x = torch.tensor(values, device='mps', requires_grad=True)

        cpu_log_softmax = F.log_softmax(cpu_x, dim=0)
        mps_log_softmax = F.log_softmax(mps_x, dim=0)
        self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))

        cpu_grad = torch.ones_like(cpu_log_softmax)
        mps_grad = torch.ones_like(cpu_log_softmax).to('mps')

        cpu_log_softmax.backward(gradient=cpu_grad)
        mps_log_softmax.backward(gradient=mps_grad)

        self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))

    def test_log_softmax_large_numbers(self):
        values = [
            [10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0],
            [-10.0, -100.0, -1000.0, -10000.0, -100000.0, -1000000.0]
        ]
        cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
        mps_x = torch.tensor(values, device='mps', requires_grad=True)

        cpu_log_softmax = F.log_softmax(cpu_x, dim=-1)
        mps_log_softmax = F.log_softmax(mps_x, dim=-1)
        self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))

        cpu_grad = torch.ones_like(cpu_log_softmax)
        mps_grad = torch.ones_like(cpu_log_softmax).to('mps')

        cpu_log_softmax.backward(gradient=cpu_grad)
        mps_log_softmax.backward(gradient=mps_grad)

        self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
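
    # The large-numbers case above relies on the shifted evaluation
    #     log_softmax(x) = (x - m) - log(sum(exp(x - m))),  m = max(x),
    # since computing exp(1e6) directly would overflow to inf on any backend.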

    def test_eq(self):
        values1 = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        values2 = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]
        mps_x = torch.tensor(values1, device='mps')
        mps_y = torch.tensor(values2, device='mps')
        cpu_x = torch.tensor(values1, device='cpu')
        cpu_y = torch.tensor(values2, device='cpu')
        result_mps = torch.eq(mps_x, mps_y)
        result_cpu = torch.eq(cpu_x, cpu_y)

        self.assertEqual(result_cpu, result_mps.to('cpu'))

    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
    def test_signed_vs_unsigned_comparison(self):
        cpu_x = torch.tensor((-1, 2, 3), device='cpu', dtype=torch.uint8)
        mps_x = torch.tensor((-1, 2, 3), device='mps', dtype=torch.uint8)
        # in the comparison of signed vs. unsigned we should always cast to unsigned
        self.assertEqual(cpu_x == -1, mps_x == -1)
        self.assertEqual(cpu_x > -1, mps_x > -1)
        self.assertEqual(cpu_x < -1, mps_x < -1)
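
    # In the test above, -1 stored into a uint8 tensor wraps to 255, and per the
    # convention stated in the test the scalar -1 on the right-hand side is cast
    # to unsigned the same way: 255 == -1 is True while 255 > -1 and 255 < -1
    # are both False, on both backends.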

    def test_eq_int64(self):
        values1 = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
        values2 = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]
        mps_x = torch.tensor(values1, device='mps')
        mps_y = torch.tensor(values2, device='mps')
        cpu_x = torch.tensor(values1, device='cpu')
        cpu_y = torch.tensor(values2, device='cpu')
        result_mps = torch.eq(mps_x, mps_y)
        result_cpu = torch.eq(cpu_x, cpu_y)

        self.assertEqual(result_cpu, result_mps.to('cpu'))

    def test_ne(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            result_mps = torch.ne(mps_x, mps_y)
            result_cpu = torch.ne(cpu_x, cpu_y)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_ne_scalar(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            result_mps = torch.ne(mps_x, 0.0)
            result_cpu = torch.ne(cpu_x, 0.0)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_lt(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            result_mps = torch.lt(mps_x, mps_y)
            result_cpu = torch.lt(cpu_x, cpu_y)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_lt_scalar(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            result_mps = torch.lt(mps_x, 0.0)
            result_cpu = torch.lt(cpu_x, 0.0)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_le(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            result_mps = torch.le(mps_x, mps_y)
            result_cpu = torch.le(cpu_x, cpu_y)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_le_scalar(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            result_mps = torch.le(mps_x, 0.0)
            result_cpu = torch.le(cpu_x, 0.0)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_ge(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            result_mps = torch.ge(mps_x, mps_y)
            result_cpu = torch.ge(cpu_x, cpu_y)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_ge_scalar(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            result_mps = torch.ge(mps_x, 0.0)
            result_cpu = torch.ge(cpu_x, 0.0)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_gt(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')
            result_mps = torch.gt(mps_x, mps_y)
            result_cpu = torch.gt(cpu_x, cpu_y)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))

    def test_gt_scalar(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
            mps_x = cpu_x.detach().clone().to('mps')
            result_mps = torch.gt(mps_x, 0.0)
            result_cpu = torch.gt(cpu_x, 0.0)

            self.assertEqual(result_cpu, result_mps.to('cpu'))

        helper((2, 3, 4, 5))
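
    # All of the elementwise comparisons above (eq/ne/lt/le/ge/gt) broadcast
    # like other binary ops and return torch.bool tensors, so the CPU/MPS check
    # reduces to exact boolean equality; e.g.
    # torch.gt(torch.tensor([1., -1.]), 0.0) -> tensor([True, False]).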

    def test_argmax(self):
        # https://github.com/pytorch/pytorch/issues/98191
        cpu_tensor = torch.tensor([[0, 1], [2, 1], [1, 0]])
        res_cpu = torch.argmax(cpu_tensor, dim=1)

        mps_tensor = cpu_tensor.to(torch.device('mps'))
        res_mps = torch.argmax(mps_tensor, dim=1)
        self.assertEqual(res_cpu, res_mps)

        # https://github.com/pytorch/pytorch/issues/92311
        mps_tensor = torch.randn(10, 2, device='mps', dtype=torch.float32)
        cpu_tensor = mps_tensor.detach().clone().cpu()

        res_mps = torch.argmax(mps_tensor, dim=1)
        res_cpu = torch.argmax(cpu_tensor, dim=1)
        self.assertEqual(res_cpu, res_mps)

    # Test forward argmin argmax
    def test_argmin_argmax(self):
        def helper(n, c, h, w, reduction_type, dtype=torch.float32):
            if reduction_type == "max":
                arg_reduction_fn = torch.argmax
            else:
                arg_reduction_fn = torch.argmin

            cpu_x = None
            x = None
            if (dtype not in [torch.float32, torch.bool]):
                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            elif (dtype == torch.bool):
                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            else:
                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

            y = arg_reduction_fn(x)
            ref_y = arg_reduction_fn(cpu_x)
            self.assertEqual(y, ref_y)

            y_0 = arg_reduction_fn(x, dim=0)
            refy_0 = arg_reduction_fn(cpu_x, dim=0)
            self.assertEqual(y_0, refy_0)

            y_0dim = arg_reduction_fn(x, dim=0, keepdim=True)
            refy_0dim = arg_reduction_fn(cpu_x, dim=0, keepdim=True)
            self.assertEqual(y_0dim, refy_0dim)

            y_1 = arg_reduction_fn(x, dim=1)
            refy_1 = arg_reduction_fn(cpu_x, dim=1)
            self.assertEqual(y_1, refy_1)

            y_1dim = arg_reduction_fn(x, dim=1, keepdim=True)
            refy_1dim = arg_reduction_fn(cpu_x, dim=1, keepdim=True)
            self.assertEqual(y_1dim, refy_1dim)

            y_2 = arg_reduction_fn(x, dim=2)
            refy_2 = arg_reduction_fn(cpu_x, dim=2)
            self.assertEqual(y_2, refy_2)

            y_2dim = arg_reduction_fn(x, dim=2, keepdim=True)
            refy_2dim = arg_reduction_fn(cpu_x, dim=2, keepdim=True)
            self.assertEqual(y_2dim, refy_2dim)

            y_3 = arg_reduction_fn(x, dim=3)
            refy_3 = arg_reduction_fn(cpu_x, dim=3)
            self.assertEqual(y_3, refy_3)

            y_3dim = arg_reduction_fn(x, dim=3, keepdim=True)
            refy_3dim = arg_reduction_fn(cpu_x, dim=3, keepdim=True)
            self.assertEqual(y_3dim, refy_3dim)

        helper(2, 8, 4, 4, "max", torch.float32)
        helper(2, 8, 4, 4, "max", torch.int32)
        helper(2, 8, 4, 4, "max", torch.float16)
        helper(2, 8, 4, 4, "max", torch.int64)
        helper(2, 8, 4, 4, "min", torch.float32)
        helper(2, 8, 4, 4, "min", torch.int32)
        helper(2, 8, 4, 4, "min", torch.float16)
        helper(2, 8, 4, 4, "min", torch.int64)

    @unittest.skipIf(product_version < 13.3, "Long data type supported from macOS 13.3 and above")
    def test_reduction_sum_max_long_val(self):
        x_mps = torch.tensor([sys.maxsize, sys.maxsize - 10, sys.maxsize - 5, sys.maxsize - 18], device="mps")
        x_cpu = x_mps.detach().clone().cpu()

        res_mps = torch.sum(x_mps)
        res_cpu = torch.sum(x_cpu)
        self.assertEqual(res_mps, res_cpu)
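
    # The sum above overflows int64 (four values near sys.maxsize), so the test
    # is really checking that both backends wrap around in two's complement the
    # same way, not that the mathematical sum is representable.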

    # Test forward max
    # Note - don't test grad now
    def test_max_el(self):
        def helper(n, c, h, w, dtype=torch.float32):

            if (dtype not in [torch.float32, torch.bool]):
                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            elif (dtype == torch.bool):
                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            else:
                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps')

            ref_y = torch.max(cpu_x)
            y = torch.max(x)
            self.assertEqual(y, ref_y)

            for dim in [0, 1, 2, 3]:
                for keepdim in [True, False]:
                    y, idx = torch.max(x, dim=dim, keepdim=keepdim)
                    refy, refidx = torch.max(cpu_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(y, refy)
                    self.assertEqual(idx, refidx)

            y_0 = torch.ones(c, h, w, device='mps', dtype=dtype)
            idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=0, out=(y_0, idx_0))
            refy_0, refidx_0 = torch.max(cpu_x, dim=0)
            self.assertEqual(y_0, refy_0)
            self.assertEqual(idx_0, refidx_0)

            y_0dim = torch.ones(1, c, h, w, device='mps', dtype=dtype)
            idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
            refy_0dim, refidx_0dim = torch.max(cpu_x, dim=0, keepdim=True)
            self.assertEqual(y_0dim, refy_0dim)
            self.assertEqual(idx_0dim, refidx_0dim)

            y_1 = torch.ones(n, h, w, device='mps', dtype=dtype)
            idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=1, out=(y_1, idx_1))
            refy_1, refidx_1 = torch.max(cpu_x, dim=1)
            self.assertEqual(y_1, refy_1)
            self.assertEqual(idx_1, refidx_1)

            y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=dtype)
            idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
            refy_1dim, refidx_1dim = torch.max(cpu_x, keepdim=True, dim=1)
            self.assertEqual(y_1dim, refy_1dim)
            self.assertEqual(idx_1dim, refidx_1dim)

            y_2 = torch.ones(n, c, w, device='mps', dtype=dtype)
            idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=2, out=(y_2, idx_2))
            refy_2, refidx_2 = torch.max(cpu_x, dim=2)
            self.assertEqual(y_2, refy_2)
            self.assertEqual(idx_2, refidx_2)

            y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=dtype)
            idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
            torch.max(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
            refy_2dim, refidx_2dim = torch.max(cpu_x, dim=2, keepdim=True)
            self.assertEqual(y_2dim, refy_2dim)
            self.assertEqual(idx_2dim, refidx_2dim)

            y_3 = torch.ones(n, c, h, device='mps', dtype=dtype)
            idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
            torch.max(x, dim=3, out=(y_3, idx_3))
            refy_3, refidx_3 = torch.max(cpu_x, dim=3)
            self.assertEqual(y_3, refy_3)
            self.assertEqual(idx_3, refidx_3)

            y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=dtype)
            idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
            torch.max(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
            refy_3dim, refidx_3dim = torch.max(cpu_x, dim=3, keepdim=True)
            self.assertEqual(y_3dim, refy_3dim)
            self.assertEqual(idx_3dim, refidx_3dim)

        helper(2, 8, 4, 5, torch.float32)
        helper(2, 8, 4, 5, torch.int32)
        # helper(2, 8, 4, 5, torch.int64)

    def test_median(self):
        def helper_dtype_int32(n1, n2, n3):
            cpu_x = torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32)
            mps_x = cpu_x.detach().clone().to('mps')

            result_cpu = torch.median(cpu_x)
            result_mps = torch.median(mps_x)

            self.assertEqual(result_cpu, result_mps)

            for dim in [0, 1, 2]:
                for keepdim in [True, False]:
                    y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
                    refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(y, refy)
                    self.assertEqual(idx, refidx)

        def helper_dtype_float32(n1, n2, n3):
            cpu_x = torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32)
            mps_x = cpu_x.detach().clone().to('mps')

            result_cpu = torch.median(cpu_x)
            result_mps = torch.median(mps_x)

            self.assertEqual(result_cpu, result_mps)

            for dim in [0, 1, 2]:
                for keepdim in [True, False]:
                    y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
                    refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim)
                    self.assertEqual(y, refy)
                    self.assertEqual(idx, refidx)

        helper_dtype_int32(10, 10, 10)  # median at even place
        helper_dtype_int32(3, 3, 3)  # median at odd place
        helper_dtype_int32(1, 1, 1)
        helper_dtype_int32(1, 2, 3)
        helper_dtype_float32(10, 10, 10)
        helper_dtype_float32(3, 3, 3)
        helper_dtype_float32(1, 1, 1)
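
    # For an even number of elements torch.median does not average the two
    # middle values; it returns the lower of them (unlike numpy.median), which
    # is why the even-sized cases above can still demand exact CPU/MPS matches.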

    def test_any(self):
        def helper(shape):
            input_xs = []
            prod = 1

            for i in range(len(shape)):
                prod *= shape[i]
            input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())

            for i, cpu_x in enumerate(input_xs):
                x = cpu_x.detach().clone().to('mps')
                y = torch.any(x)
                ref_y = torch.any(cpu_x)
                self.assertEqual(y, ref_y)

                y_0 = torch.any(x, dim=0)
                refy_0 = torch.any(cpu_x, dim=0)
                self.assertEqual(y_0, refy_0)

                y_0dim = torch.any(x, dim=0, keepdim=True)
                refy_0dim = torch.any(cpu_x, dim=0, keepdim=True)
                self.assertEqual(y_0dim, refy_0dim)

                y_0dim = torch.any(x, dim=0, keepdim=True)
                refy_0dim = torch.any(cpu_x, dim=0, keepdim=True)
                self.assertEqual(y_0dim, refy_0dim)

                y_1 = torch.any(x, dim=1)
                refy_1 = torch.any(cpu_x, dim=1)
                self.assertEqual(y_1, refy_1)

                y_1dim = torch.any(x, dim=1, keepdim=True)
                refy_1dim = torch.any(cpu_x, dim=1, keepdim=True)
                self.assertEqual(y_1dim, refy_1dim)

                if (len(shape) > 2):
                    y_2 = torch.any(x, dim=2)
                    refy_2 = torch.any(cpu_x, dim=2)
                    self.assertEqual(y_2, refy_2)

                    y_2dim = torch.any(x, dim=2, keepdim=True)
                    refy_2dim = torch.any(cpu_x, dim=2, keepdim=True)
                    self.assertEqual(y_2dim, refy_2dim)

                    y_3 = torch.any(x, dim=3)
                    refy_3 = torch.any(cpu_x, dim=3)
                    self.assertEqual(y_3, refy_3)

                    y_3dim = torch.any(x, dim=3, keepdim=True)
                    refy_3dim = torch.any(cpu_x, dim=3, keepdim=True)
                    self.assertEqual(y_3dim, refy_3dim)
        helper((1, 1, 1, 1))
        helper((1, 1, 3, 3))
        helper((7, 13))
        helper((2, 8, 4, 5))

    def test_reduction_ops_5D(self):
        def helper(fn, dim):
            shape = (1, 1, 2, 1, 1)
            x_cpu = fn(torch.zeros(shape), dim=dim)
            x_mps = fn(torch.zeros(shape, device="mps"), dim=dim)
            self.assertEqual(x_cpu, x_mps.cpu())
        for fn in [torch.any, torch.all]:
            for dim in range(0, 4):
                helper(fn, dim)

        # 6D tensor reductions
        # Regression test for https://github.com/pytorch/pytorch/issues/95538
        x = (torch.rand(2, 3, 4, 3, 4, 2, device="mps") - .5).relu()
        self.assertEqual(x.all(), x.cpu().all())
        for i in range(-5, 6):
            self.assertEqual(x.all(dim=i), x.cpu().all(dim=i))

    def test_all(self):
        def helper(shape):
            input_xs = []
            prod = 1

            for i in range(len(shape)):
                prod *= shape[i]
            input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())

            for i, cpu_x in enumerate(input_xs):
                x = cpu_x.detach().clone().to('mps')
                y = torch.all(x)
                ref_y = torch.all(cpu_x)
                self.assertEqual(y, ref_y)

                y_0 = torch.all(x, dim=0)
                refy_0 = torch.all(cpu_x, dim=0)
                self.assertEqual(y_0, refy_0)

                y_0dim = torch.all(x, dim=0, keepdim=True)
                refy_0dim = torch.all(cpu_x, dim=0, keepdim=True)
                self.assertEqual(y_0dim, refy_0dim)

                y_0dim = torch.all(x, dim=0, keepdim=True)
                refy_0dim = torch.all(cpu_x, dim=0, keepdim=True)
                self.assertEqual(y_0dim, refy_0dim)

                y_1 = torch.all(x, dim=1)
                refy_1 = torch.all(cpu_x, dim=1)
                self.assertEqual(y_1, refy_1)

                y_1dim = torch.all(x, dim=1, keepdim=True)
                refy_1dim = torch.all(cpu_x, dim=1, keepdim=True)
                self.assertEqual(y_1dim, refy_1dim)
                if (len(shape) > 2):
                    y_2 = torch.all(x, dim=2)
                    refy_2 = torch.all(cpu_x, dim=2)
                    self.assertEqual(y_2, refy_2)

                    y_2dim = torch.all(x, dim=2, keepdim=True)
                    refy_2dim = torch.all(cpu_x, dim=2, keepdim=True)
                    self.assertEqual(y_2dim, refy_2dim)

                    y_3 = torch.all(x, dim=3)
                    refy_3 = torch.all(cpu_x, dim=3)
                    self.assertEqual(y_3, refy_3)

                    y_3dim = torch.all(x, dim=3, keepdim=True)
                    refy_3dim = torch.all(cpu_x, dim=3, keepdim=True)
                    self.assertEqual(y_3dim, refy_3dim)

        helper((1, 1, 1, 1))
        helper((1, 1, 3, 3))
        helper((7, 13))
        helper((2, 8, 4, 5))
        # Empty tensor
        x_cpu = torch.tensor([], dtype=torch.bool)
        x_mps = x_cpu.to("mps")
        self.assertEqual(x_cpu.all(), x_mps.all().cpu())
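
    # all() on an empty tensor is True by convention (a vacuously true universal
    # quantifier), while any() on an empty tensor is False; the empty-tensor
    # check above pins the MPS backend to that convention.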

    # Test forward min
    def test_min_el(self):
        def helper(n, c, h, w):
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            y = torch.min(x)
            ref_y = torch.min(cpu_x)
            self.assertEqual(y, ref_y)

            y_0, idx_0 = torch.min(x, dim=0)
            refy_0, refidx_0 = torch.min(cpu_x, dim=0)
            self.assertEqual(y_0, refy_0)
            self.assertEqual(idx_0, refidx_0)

            y_0 = torch.ones(c, h, w, device='mps', dtype=torch.float)
            idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=0, out=(y_0, idx_0))
            refy_0, refidx_0 = torch.min(cpu_x, dim=0)
            self.assertEqual(y_0, refy_0)
            self.assertEqual(idx_0, refidx_0)

            y_0dim, idx_0dim = torch.min(x, dim=0, keepdim=True)
            refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
            self.assertEqual(y_0dim, refy_0dim)
            self.assertEqual(idx_0dim, refidx_0dim)

            y_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.float)
            idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
            refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
            self.assertEqual(y_0dim, refy_0dim)
            self.assertEqual(idx_0dim, refidx_0dim)

            y_1, idx_1 = torch.min(x, dim=1)
            refy_1, refidx_1 = torch.min(cpu_x, dim=1)
            self.assertEqual(y_1, refy_1)
            self.assertEqual(idx_1, refidx_1)

            y_1 = torch.ones(n, h, w, device='mps', dtype=torch.float)
            idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=1, out=(y_1, idx_1))
            refy_1, refidx_1 = torch.min(cpu_x, dim=1)
            self.assertEqual(y_1, refy_1)
            self.assertEqual(idx_1, refidx_1)

            y_1dim, idx_1dim = torch.min(x, dim=1, keepdim=True)
            refy_1dim, refidx_1dim = torch.min(cpu_x, dim=1, keepdim=True)
            self.assertEqual(y_1dim, refy_1dim)
            self.assertEqual(idx_1dim, refidx_1dim)

            y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.float)
            idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
            refy_1dim, refidx_1dim = torch.min(cpu_x, keepdim=True, dim=1)
            self.assertEqual(y_1dim, refy_1dim)
            self.assertEqual(idx_1dim, refidx_1dim)

            y_2, idx_2 = torch.min(x, dim=2)
            refy_2, refidx_2 = torch.min(cpu_x, dim=2)
            self.assertEqual(y_2, refy_2)
            self.assertEqual(idx_2, refidx_2)

            y_2 = torch.ones(n, c, w, device='mps', dtype=torch.float)
            idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=2, out=(y_2, idx_2))
            refy_2, refidx_2 = torch.min(cpu_x, dim=2)
            self.assertEqual(y_2, refy_2)
            self.assertEqual(idx_2, refidx_2)

            y_2dim, idx_2dim = torch.min(x, dim=2, keepdim=True)
            refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True)
            self.assertEqual(y_2dim, refy_2dim)
            self.assertEqual(idx_2dim, refidx_2dim)

            y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.float)
            idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
            torch.min(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
            refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True)
            self.assertEqual(y_2dim, refy_2dim)
            self.assertEqual(idx_2dim, refidx_2dim)

            y_3, idx_3 = torch.min(x, dim=3)
            refy_3, refidx_3 = torch.min(cpu_x, dim=3)
            self.assertEqual(y_3, refy_3)
            self.assertEqual(idx_3, refidx_3)

            y_3 = torch.ones(n, c, h, device='mps', dtype=torch.float)
            idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
            torch.min(x, dim=3, out=(y_3, idx_3))
            refy_3, refidx_3 = torch.min(cpu_x, dim=3)
            self.assertEqual(y_3, refy_3)
            self.assertEqual(idx_3, refidx_3)

            y_3dim, idx_3dim = torch.min(x, dim=3, keepdim=True)
            refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True)
            self.assertEqual(y_3dim, refy_3dim)
            self.assertEqual(idx_3dim, refidx_3dim)

            y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.float)
            idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
            torch.min(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
            refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True)
            self.assertEqual(y_3dim, refy_3dim)
            self.assertEqual(idx_3dim, refidx_3dim)

        helper(2, 8, 4, 5)

    # Test forward sum
    def test_sum(self):
        def helper(n, c, h, w, dtype=torch.float32):
            cpu_x = None
            x = None
            if (dtype not in [torch.float32, torch.bool]):
                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            elif (dtype == torch.bool):
                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            else:
                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

            all_sum = torch.sum(x)
            all_sum_cpu = torch.sum(cpu_x)

            self.assertEqual(all_sum, all_sum_cpu)

            nil_dim_sum = torch.sum(x, dim=[])
            nil_dim_sum_cpu = torch.sum(cpu_x, dim=[])

            self.assertEqual(nil_dim_sum, nil_dim_sum_cpu)

            nil_dim_sum_keepdim = torch.sum(x, dim=[], keepdim=True)
            nil_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[], keepdim=True)

            self.assertEqual(nil_dim_sum_keepdim, nil_dim_sum_cpu_keepdim)

            zero_dim_sum = torch.sum(x, dim=[0])
            zero_dim_sum_cpu = torch.sum(cpu_x, dim=[0])

            self.assertEqual(zero_dim_sum, zero_dim_sum_cpu)

            zero_dim_sum_keepdim = torch.sum(x, dim=[0], keepdim=True)
            zero_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0], keepdim=True)

            self.assertEqual(zero_dim_sum_keepdim, zero_dim_sum_cpu_keepdim)

            zero_one_dim_sum = torch.sum(x, dim=[0, 1])
            zero_one_dim_sum_cpu = torch.sum(cpu_x, dim=[0, 1])

            self.assertEqual(zero_one_dim_sum, zero_one_dim_sum_cpu)

            zero_one_dim_sum_keepdim = torch.sum(x, dim=[0, 1], keepdim=True)
            zero_one_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0, 1], keepdim=True)

            self.assertEqual(zero_one_dim_sum_keepdim, zero_one_dim_sum_cpu_keepdim)

            two_three_dim_sum = torch.sum(x, dim=[2, 3])
            two_three_dim_sum_cpu = torch.sum(cpu_x, dim=[2, 3])

            self.assertEqual(two_three_dim_sum, two_three_dim_sum_cpu)

            two_three_keepdim_sum = torch.sum(x, dim=[2, 3], keepdim=True)
            two_three_dim_keepsum_cpu = torch.sum(cpu_x, dim=[2, 3], keepdim=True)

            self.assertEqual(two_three_keepdim_sum, two_three_dim_keepsum_cpu)

        helper(2, 8, 4, 5)
        helper(2, 8, 4, 5, dtype=torch.int32)
        helper(2, 8, 4, 5, dtype=torch.int64)
        helper(2, 8, 4, 5, dtype=torch.bool)
        # Regression test for https://github.com/pytorch/pytorch/issues/136132
        x = torch.ones(2, 4, 1, 30, 1, device='mps').sum(dim=-2)
        self.assertEqual(x.numel(), 8)
        self.assertEqual(x.max().item(), 30.0)
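
    # torch.sum promotes integral inputs (including bool) to int64 by default,
    # so the int32/bool variants above compare int64 results on both backends;
    # e.g. torch.ones(3, dtype=torch.bool).sum() -> tensor(3) with dtype int64.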

    # Test forward prod
    def test_prod(self):
        def helper(shape, dtype=torch.float32):
            cpu_x = None
            x = None
            if (dtype not in [torch.float32, torch.bool]):
                cpu_x = torch.randint(1, 6, shape, device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            elif (dtype == torch.bool):
                cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
                x = cpu_x.detach().clone().to('mps')
            else:
                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
                x = cpu_x.detach().clone().to('mps').requires_grad_()

            all_prod = torch.prod(x)
            all_prod_cpu = torch.prod(cpu_x)

            self.assertEqual(all_prod, all_prod_cpu)

            for dim in range(len(shape)):
                dim_prod = torch.prod(x, dim=dim)
                dim_prod_cpu = torch.prod(cpu_x, dim=dim)

                self.assertEqual(dim_prod, dim_prod_cpu)

                dim_prod_keepdim = torch.prod(x, dim=dim, keepdim=True)
                dim_prod_cpu_keepdim = torch.prod(cpu_x, dim=dim, keepdim=True)

                self.assertEqual(dim_prod_keepdim, dim_prod_cpu_keepdim)

        for dtype in [torch.float32, torch.int32, torch.int64, torch.bool]:
            helper((2, 3), dtype)

    # Test forward mean
    def test_mean(self):
        def helper(n, c, h, w):
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            all_mean = torch.mean(x)
            all_mean_cpu = torch.mean(cpu_x)

            self.assertEqual(all_mean, all_mean_cpu)

            nil_dim_mean = torch.mean(x, dim=[])
            nil_dim_mean_cpu = torch.mean(cpu_x, dim=[])

            self.assertEqual(nil_dim_mean, nil_dim_mean_cpu)

            nil_dim_mean_keepdim = torch.mean(x, dim=[], keepdim=True)
            nil_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[], keepdim=True)

            self.assertEqual(nil_dim_mean_keepdim, nil_dim_mean_cpu_keepdim)

            zero_dim_mean = torch.mean(x, dim=[0])
            zero_dim_mean_cpu = torch.mean(cpu_x, dim=[0])

            self.assertEqual(zero_dim_mean, zero_dim_mean_cpu)

            zero_dim_mean_keepdim = torch.mean(x, dim=[0], keepdim=True)
            zero_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0], keepdim=True)

            self.assertEqual(zero_dim_mean_keepdim, zero_dim_mean_cpu_keepdim)

            zero_one_dim_mean = torch.mean(x, dim=[0, 1])
            zero_one_dim_mean_cpu = torch.mean(cpu_x, dim=[0, 1])

            self.assertEqual(zero_one_dim_mean, zero_one_dim_mean_cpu)

            zero_one_dim_mean_keepdim = torch.mean(x, dim=[0, 1], keepdim=True)
            zero_one_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0, 1], keepdim=True)

            self.assertEqual(zero_one_dim_mean_keepdim, zero_one_dim_mean_cpu_keepdim)

            two_three_dim_mean = torch.mean(x, dim=[2, 3])
            two_three_dim_mean_cpu = torch.mean(cpu_x, dim=[2, 3])

            self.assertEqual(two_three_dim_mean, two_three_dim_mean_cpu)

            two_three_keepdim_mean = torch.mean(x, dim=[2, 3], keepdim=True)
            two_three_dim_keepmean_cpu = torch.mean(cpu_x, dim=[2, 3], keepdim=True)

            self.assertEqual(two_three_keepdim_mean, two_three_dim_keepmean_cpu)

        helper(2, 8, 4, 5)

    # Test std
    def test_std(self):
        def helper(shape):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            all_std = torch.std(x, unbiased=False)
            all_std_cpu = torch.std(cpu_x, unbiased=False)

            self.assertEqual(all_std, all_std_cpu)

            nil_dim_std = torch.std(x, dim=[], unbiased=False)
            nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=False)

            self.assertEqual(nil_dim_std, nil_dim_std_cpu)

            nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=False)
            nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=False)

            self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)

            zero_dim_std = torch.std(x, dim=[0], unbiased=False)
            zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=False)

            self.assertEqual(zero_dim_std, zero_dim_std_cpu)

            zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=False)
            zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=False)

            self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)

            zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=False)
            zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=False)

            self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)

            zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=False)
            zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)

            self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)

            two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=False)
            two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=False)

            self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)

            two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=False)
            two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)

            self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)

            all_std = torch.std(x, unbiased=True)
            all_std_cpu = torch.std(cpu_x, unbiased=True)

            self.assertEqual(all_std, all_std_cpu)

            nil_dim_std = torch.std(x, dim=[], unbiased=True)
            nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=True)

            self.assertEqual(nil_dim_std, nil_dim_std_cpu)

            nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=True)
            nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=True)

            self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)

            zero_dim_std = torch.std(x, dim=[0], unbiased=True)
            zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=True)

            self.assertEqual(zero_dim_std, zero_dim_std_cpu)

            zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=True)
            zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=True)

            self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)

            zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=True)
            zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=True)

            self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)

            zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=True)
            zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)

            self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)

            two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=True)
            two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=True)

            self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)

            two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=True)
            two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)

            self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)

        helper((4, 5, 6, 7))
        # verify if a change in shape of input would cause problems with graph caching
        helper((9, 5, 6, 7))
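
    # unbiased=True applies Bessel's correction (divide by N - 1 instead of N),
    # so relative to the unbiased=False pass exercised first, var scales by
    # N / (N - 1) and std by sqrt(N / (N - 1)).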

    # Test var
    def test_var_simple(self):
        def helper():

            shape = [2, 3, 4, 5]

            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            for unbiased in [False, True]:
                for keepdim in [False, True]:

                    zero_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
                    zero_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)

                    self.assertEqual(zero_dim_var, zero_dim_var_cpu)

                    all_var = torch.var(x, unbiased=unbiased)
                    all_var_cpu = torch.var(cpu_x, unbiased=unbiased)

                    self.assertEqual(all_var, all_var_cpu)

                    nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, unbiased=unbiased)
                    nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, unbiased=unbiased)

                    self.assertEqual(nil_dim_var, nil_dim_var_cpu)

                    zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, unbiased=unbiased)
                    zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, unbiased=unbiased)

                    self.assertEqual(zero_dim_var, zero_dim_var_cpu)

                    zero_one_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
                    zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)

                    self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)

                    two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
                    two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)

                    self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)

        helper()

    # Test forward amax
    def test_amax(self):
        def helper(shape, dim, keepdim):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            result = torch.amax(x, dim=dim, keepdim=keepdim)
            result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)

            cpu_grad = torch.randn(result_cpu.shape)
            grad = cpu_grad.to('mps')

            result_cpu.backward(gradient=cpu_grad)
            result.backward(gradient=grad)

            self.assertEqual(result, result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        for dim in ([], [0], [0, 1], [2, 3]):
            for keepdim in [False, True]:
                helper((2, 8, 4, 5), dim, keepdim)

    # Test forward amin
    def test_amin(self):
        def helper(shape, dim, keepdim):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            result = torch.amin(x, dim=dim, keepdim=keepdim)
            result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)

            cpu_grad = torch.randn(result_cpu.shape)
            grad = cpu_grad.to('mps')

            result_cpu.backward(gradient=cpu_grad)
            result.backward(gradient=grad)

            self.assertEqual(result, result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        for dim in ([], [0], [0, 1], [2, 3]):
            for keepdim in [False, True]:
                helper((2, 8, 4, 5), dim, keepdim)

    # Test minimum and maximum
    def test_minimum_maximum(self):
        def helper(n, c, h, w):
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            cpu_y = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            mps_x = cpu_x.detach().clone().to('mps')
            mps_y = cpu_y.detach().clone().to('mps')

            minimum_result_cpu = torch.minimum(cpu_x, cpu_y)
            minimum_result_mps = torch.minimum(mps_x, mps_y)
            self.assertEqual(minimum_result_cpu, minimum_result_mps)

            maximum_result_cpu = torch.maximum(cpu_x, cpu_y)
            maximum_result_mps = torch.maximum(mps_x, mps_y)
            self.assertEqual(maximum_result_cpu, maximum_result_mps)

        helper(1, 1, 4, 5)

    def test_clamp_fp16_fp32(self):
        cpu_x = torch.randn(10, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')

        dtype = torch.float16

        clamp_min_vals_mps = torch.ones(10, device="mps").to(torch.float16)
        clamp_max_vals_mps = torch.ones(10, device="mps").to(torch.float16) * 10
        clamp_result_mps = torch.clamp(x, clamp_min_vals_mps, clamp_max_vals_mps)

        clamp_min_vals_cpu = torch.ones(10, device="cpu").to(torch.float16)
        clamp_max_vals_cpu = torch.ones(10, device="cpu").to(torch.float16) * 10
        clamp_result_cpu = torch.clamp(cpu_x, clamp_min_vals_cpu, clamp_max_vals_cpu)

        self.assertEqual(clamp_result_mps, clamp_result_cpu)
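
    # torch.clamp propagates NaN rather than clamping it into range: min/max
    # comparisons with NaN are false, so a NaN input stays NaN in the CPU
    # reference, and the test below checks that MPS agrees.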

    def test_clamp_nan(self):
        t_mps = torch.tensor([torch.nan, 1, 2], device="mps")
        t_cpu = torch.tensor([torch.nan, 1, 2], device="cpu")

        clamp_min_max_mps = torch.clamp(t_mps, min=-100, max=100)
        clamp_min_max_cpu = torch.clamp(t_cpu, min=-100, max=100)

        self.assertEqual(clamp_min_max_mps, clamp_min_max_cpu)

        clamp_min_mps = torch.clamp(t_mps, min=-100)
        clamp_min_cpu = torch.clamp(t_cpu, min=-100)

        self.assertEqual(clamp_min_mps, clamp_min_cpu)

        clamp_max_mps = torch.clamp(t_mps, max=100)
        clamp_max_cpu = torch.clamp(t_cpu, max=100)

        self.assertEqual(clamp_max_mps, clamp_max_cpu)

    # Test clamp_min
    def test_clamp_min(self):
        def helper(n, c, h, w):
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_min_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            min_t = cpu_min_t.detach().clone().to('mps')

            clamp_min_result = torch.clamp_min(x, min=5.0)
            clamp_min_result_cpu = torch.clamp_min(cpu_x, min=5.0)

            self.assertEqual(clamp_min_result, clamp_min_result_cpu)

            clamp_min_t_result = torch.clamp_min(x, min=min_t)
            clamp_min_t_result_cpu = torch.clamp_min(cpu_x, min=cpu_min_t)

            self.assertEqual(clamp_min_t_result, clamp_min_t_result_cpu)

        helper(2, 8, 4, 5)

    # Test clamp_max
    def test_clamp_max(self):
        def helper(n, c, h, w):
            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            cpu_max_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
            max_t = cpu_max_t.detach().clone().to('mps')

            clamp_max_result = torch.clamp_max(x, max=100.0)
            clamp_max_result_cpu = torch.clamp_max(cpu_x, max=100.0)

            self.assertEqual(clamp_max_result, clamp_max_result_cpu)

            clamp_max_t_result = torch.clamp_max(x, max=max_t)
            clamp_max_t_result_cpu = torch.clamp_max(cpu_x, max=cpu_max_t)

            self.assertEqual(clamp_max_t_result, clamp_max_t_result_cpu)

        helper(2, 8, 4, 5)

    # Test clamp
    def test_clamp(self):
        def helper(n, c, h, w):
            import numpy as np
            upper_bound = 1000
            half_upper_bound = upper_bound / 2

            # x=[0..1000)
            x_arr = upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
            cpu_x = torch.tensor(x_arr, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            # min=[0..500)
            min_arr = half_upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
            cpu_min_t = torch.tensor(min_arr, device='cpu', dtype=torch.float, requires_grad=False)
            min_t = cpu_min_t.detach().clone().to('mps')

            # max=[500..1000), to ensure max's are greater than mins
            max_arr = (half_upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)) + half_upper_bound
            cpu_max_t = torch.tensor(max_arr, device='cpu', dtype=torch.float, requires_grad=False)
            max_t = cpu_max_t.detach().clone().to('mps')

            # [200..600]: just an arbitrary range between [0..1000]
            clamp_result = torch.clamp(x, min=200.0, max=600.0)
            clamp_result_cpu = torch.clamp(cpu_x, min=200.0, max=600.0)
            self.assertEqual(clamp_result, clamp_result_cpu)

            # test optional scalar refs and cached graph keys by passing only max
            clamp_opt_result = torch.clamp(x, max=600.0)
            clamp_opt_result_cpu = torch.clamp(cpu_x, max=600.0)
            self.assertEqual(clamp_opt_result, clamp_opt_result_cpu)

            clamp_t_result = torch.clamp(x, min=min_t, max=max_t)
            clamp_t_result_cpu = torch.clamp(cpu_x, min=cpu_min_t, max=cpu_max_t)
            self.assertEqual(clamp_t_result, clamp_t_result_cpu)

            # test optional tensor refs and cached graph keys by passing only max
            clamp_topt_result = torch.clamp(x, max=max_t)
            clamp_topt_result_cpu = torch.clamp(cpu_x, max=cpu_max_t)
            self.assertEqual(clamp_topt_result, clamp_topt_result_cpu)

            # test strided x
            clamp_result = torch.clamp(x.movedim(0, -1), min=200.0, max=600.0)
            clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=200.0, max=600.0)
            self.assertEqual(clamp_result, clamp_result_cpu)

            # test strided x, min_t, max_t
            clamp_result = torch.clamp(x.movedim(0, -1), min=min_t.movedim(0, -1), max=max_t.movedim(0, -1))
            clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=cpu_min_t.movedim(0, -1), max=cpu_max_t.movedim(0, -1))
            self.assertEqual(clamp_result, clamp_result_cpu)

            # test strided min_t, max_t
            clamp_result = torch.clamp(
                x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
                min=min_t.movedim(0, -1),
                max=max_t.movedim(0, -1)
            )
            clamp_result_cpu = torch.clamp(
                cpu_x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
                min=cpu_min_t.movedim(0, -1),
                max=cpu_max_t.movedim(0, -1)
            )
            self.assertEqual(clamp_result, clamp_result_cpu)

            # test inplace clamping
            x.clamp_(min=200.0, max=600.0)
            cpu_x.clamp_(min=200.0, max=600.0)
            self.assertEqual(cpu_x, x)

        helper(2, 8, 4, 5)

    def test_divmode(self):
        def helper(shape, rounding_mode):
            for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
                if not ((rounding_mode is not None and "floor" in rounding_mode and dtype == torch.int64) or
                        (rounding_mode is not None and "trunc" in rounding_mode and dtype == torch.float16)):
                    cpu_x = None
                    cpu_y = None
                    if (dtype in [torch.float32, torch.float16]):
                        cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
                        cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
                    else:
                        # divisors drawn from [-10, 0) are never zero
                        cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
                        cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)

                    mps_x = cpu_x.detach().clone().to('mps')
                    mps_y = cpu_y.detach().clone().to('mps')

                    if (rounding_mode == "floor_divide"):
                        result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
                        result_div_mps = torch.floor_divide(mps_x, mps_y)
                        self.assertEqual(result_div_mps, result_div_cpu)
                    else:
                        result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
                        result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
                        self.assertEqual(result_div_mps, result_div_cpu)

        helper((2, 8, 4, 5), None)
        helper((2, 8, 4, 5), "floor")
        helper((2, 8, 4, 5), "trunc")
        helper((2, 8, 4, 5), "floor_divide")
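
    # 'floor' and 'trunc' rounding only differ for negative quotients, which is
    # the regime the integer inputs above are drawn from: e.g.
    # torch.div(-7, 2, rounding_mode='floor') -> -4, but 'trunc' gives -3.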
helper((2, 8, 4, 5), "floor_divide") 6208 6209 def test_rounding(self): 6210 def helper(shape): 6211 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6212 mps_x = cpu_x.detach().clone().to('mps') 6213 6214 result_floor_cpu = torch.floor(cpu_x) 6215 result_floor_mps = torch.floor(mps_x) 6216 self.assertEqual(result_floor_mps, result_floor_cpu) 6217 6218 result_ceil_cpu = torch.ceil(cpu_x) 6219 result_ceil_mps = torch.ceil(mps_x) 6220 self.assertEqual(result_ceil_mps, result_ceil_cpu) 6221 6222 result_trunc_cpu = torch.trunc(cpu_x) 6223 result_trunc_mps = torch.trunc(mps_x) 6224 self.assertEqual(result_trunc_mps, result_trunc_cpu) 6225 6226 result_round_cpu = torch.round(cpu_x) 6227 result_round_mps = torch.round(mps_x) 6228 self.assertEqual(result_round_mps, result_round_cpu) 6229 6230 helper((2, 6, 3, 5)) 6231 helper((2, 8, 4, 5)) 6232 6233 def test_remainder(self): 6234 res_cpu = torch.remainder( 6235 torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device="cpu"), torch.tensor(2, device="cpu", dtype=torch.int32)) 6236 res_mps = torch.remainder( 6237 torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device="mps"), torch.tensor(2, device="mps", dtype=torch.int32)) 6238 self.assertEqual(res_cpu, res_mps) 6239 6240 res_cpu = torch.remainder( 6241 torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="cpu"), -1.5) 6242 res_mps = torch.remainder( 6243 torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="mps"), -1.5) 6244 self.assertEqual(res_cpu, res_mps) 6245 6246 def test_expand(self): 6247 def helper(n, c): 6248 values = [[1.0], [4.0], [7.0]] 6249 cpu_x = torch.tensor(values, device='cpu') 6250 x = cpu_x.detach().clone().to('mps') 6251 6252 strided_cpu = torch.as_strided(cpu_x, (3, 4), (1, 0)) 6253 strided_mps = torch.as_strided(x, (3, 4), (1, 0)) 6254 6255 self.assertEqual(strided_mps, strided_cpu) 6256 6257 helper(3, 1) 6258 6259 def test_im2col(self): 6260 def helper(x): 6261 return torch.nn.functional.unfold(x, kernel_size=(10, 15), dilation=2, padding=5, stride=3) 6262 x_cpu = torch.rand(1, 1, 200, 100) 6263 x = x_cpu.detach().clone().to('mps') 6264 self.assertEqual(helper(x_cpu), helper(x)) 6265 6266 def test_select(self): 6267 def helper(n, c): 6268 cpu_x = torch.randn(n, c, device='cpu', dtype=torch.float, requires_grad=True) 6269 x = cpu_x.detach().clone().to('mps').requires_grad_() 6270 6271 strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1)) 6272 strided_mps = torch.as_strided(x, (3, 1), (3, 1)) 6273 self.assertEqual(strided_mps, strided_cpu) 6274 6275 strided_cpu = torch.as_strided(cpu_x, (1, 3), (3, 1)) 6276 strided_mps = torch.as_strided(x, (1, 3), (3, 1)) 6277 self.assertEqual(strided_mps, strided_cpu) 6278 6279 strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1), storage_offset=1) 6280 strided_mps = torch.as_strided(x, (3, 1), (3, 1), storage_offset=1) 6281 6282 self.assertEqual(strided_mps, strided_cpu) 6283 6284 helper(3, 3) 6285 6286 def test_sort(self): 6287 for SIZE in (4, 2049): 6288 device = 'mps' 6289 x = torch.rand(4, SIZE, device=device) 6290 res1val, res1ind = torch.sort(x) 6291 6292 res2val = torch.tensor((), device=device) 6293 res2ind = torch.tensor((), device=device, dtype=torch.long) 6294 torch.sort(x, out=(res2val, res2ind)) 6295 self.assertEqual(res1val, res2val, atol=0, rtol=0) 6296 self.assertEqual(res1ind, res2ind, atol=0, rtol=0) 6297 self.assertEqual(torch.argsort(x), res1ind) 6298 self.assertEqual(x.argsort(), res1ind) 6299 6300 self.assertEqual( 6301 torch.sort(torch.tensor((50, 40, 30, 20, 10), 
device=device))[0], 6302 torch.tensor((10, 20, 30, 40, 50), device=device), 6303 atol=0, rtol=0 6304 ) 6305 6306 def test_upsample_nearest2d(self): 6307 def helper(N, C, H, W, memory_format): 6308 inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float, 6309 requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format) 6310 inputCPU.retain_grad() 6311 inputMPS = inputCPU.detach().to('mps').requires_grad_() 6312 6313 values = [1, 2, 5, 10, 40] 6314 6315 for i in values: 6316 for j in values: 6317 upsample_nearest2d = nn.UpsamplingNearest2d(scale_factor=(i, j)) 6318 6319 outputCPU = upsample_nearest2d(inputCPU) 6320 outputMPS = upsample_nearest2d(inputMPS) 6321 6322 self.assertEqual(outputCPU, outputMPS) 6323 upsample_nearest2d = nn.UpsamplingNearest2d((i * H, j * W)) 6324 6325 outputCPU = upsample_nearest2d(inputCPU) 6326 outputMPS = upsample_nearest2d(inputMPS) 6327 6328 self.assertEqual(outputCPU, outputMPS) 6329 6330 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3)) 6331 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3)) 6332 6333 self.assertEqual(inputCPU.grad, inputMPS.grad) 6334 6335 for memory_format in [torch.channels_last, torch.contiguous_format]: 6336 helper(1, 1, 4, 4, memory_format=memory_format) 6337 helper(7, 5, 3, 2, memory_format=memory_format) 6338 6339 def test_upsample_bilinear2d(self): 6340 def helper(N, C, H, W): 6341 inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float, 6342 requires_grad=True).reshape(N, C, H, W) 6343 inputCPU.retain_grad() 6344 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_() 6345 6346 values = [1, 2, 5, 10, 40] 6347 6348 for i in values: 6349 for j in values: 6350 upsample_bilinear2d = nn.UpsamplingBilinear2d(scale_factor=(i, j)) 6351 6352 outputCPU = upsample_bilinear2d(inputCPU) 6353 outputMPS = upsample_bilinear2d(inputMPS) 6354 6355 self.assertEqual(outputCPU, outputMPS) 6356 6357 upsample_bilinear2d = nn.UpsamplingBilinear2d((i * H, j * W)) 6358 6359 outputCPU = upsample_bilinear2d(inputCPU) 6360 outputMPS = upsample_bilinear2d(inputMPS) 6361 6362 self.assertEqual(outputCPU, outputMPS) 6363 6364 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3)) 6365 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3)) 6366 6367 self.assertEqual(inputCPU.grad, inputMPS.grad) 6368 6369 helper(1, 1, 4, 4) 6370 helper(7, 5, 3, 2) 6371 6372 def test_interpolate(self): 6373 def helper(shape, output_size, scales, mode, align_corners=False): 6374 inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 6375 inputCPU.retain_grad() 6376 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_() 6377 6378 # align_corners is used for 2D interpolation only 6379 if (align_corners is True and len(shape) > 3 and mode == 'bilinear'): 6380 if scales is not None: 6381 outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode, align_corners=align_corners) 6382 outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode, align_corners=align_corners) 6383 else: 6384 outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode, align_corners=align_corners) 6385 outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode, align_corners=align_corners) 6386 elif scales is not None: 6387 outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode) 6388 outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode) 6389 else: 6390 outputCPU = 
nn.functional.interpolate(inputCPU, size=output_size, mode=mode) 6391 outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode) 6392 6393 self.assertEqual(outputCPU, outputMPS) 6394 6395 # backward pass (chose 0.6 just to have the grad_output != 1) 6396 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6)) 6397 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6)) 6398 self.assertEqual(inputCPU.grad, inputMPS.grad) 6399 6400 # 1D interpolation 6401 for mode in ['nearest', 'nearest-exact']: 6402 helper([2, 3, 4], [3], None, mode) # downsample with size 6403 helper([2, 3, 4], [6], None, mode) # upsample with size 6404 helper([2, 3, 4], None, [0.6], mode) # downsample with scale factor 6405 helper([2, 3, 4], None, [1.7], mode) # upsample with scale factor 6406 # 2D interpolation 6407 for mode in ['nearest', 'nearest-exact', 'bilinear']: 6408 helper([2, 3, 4, 5], [3, 4], None, mode) # downsample_nearest with size 6409 helper([2, 3, 4, 5], [6, 7], None, mode) # upsample_nearest with size 6410 helper([2, 3, 4, 5], None, [0.6, 0.7], mode) # downsample_nearest with scale factor 6411 helper([2, 3, 4, 5], None, [1.4, 1.7], mode) # upsample_nearest with scale factor 6412 # align_corners=True 6413 helper([2, 3, 4, 5], [3, 4], None, 'bilinear', True) 6414 helper([2, 3, 4, 5], None, [1.4, 1.7], 'bilinear', True) 6415 6416 # Test concat forward 6417 def test_cat1(self): 6418 def helper(shape_x, shape_y, shape_z): 6419 cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False) 6420 x = cpu_x.detach().clone().to('mps') 6421 6422 cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False) 6423 y = cpu_y.detach().clone().to('mps') 6424 6425 cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False) 6426 z = cpu_z.detach().clone().to('mps') 6427 6428 cat = torch.cat([x, y, z], dim=1) 6429 cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1) 6430 6431 self.assertEqual(cat, cat_cpu) 6432 6433 helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5]) 6434 helper([2, 2, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5]) 6435 helper([0, 2, 4, 5], [0, 3, 4, 5], [0, 5, 4, 5]) 6436 helper([2, 2, 6, 5], [0], [2, 5, 6, 5]) 6437 helper([0], [2, 3, 6, 5], [2, 5, 6, 5]) 6438 helper([2, 3, 4, 5], [2, 5, 4, 5], [0]) 6439 helper([2, 2, 6, 5], [2, 0, 6, 5], [2, 5, 6, 5]) 6440 helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5]) 6441 helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 0, 6, 5]) 6442 6443 # Test stack forward 6444 def test_stack(self): 6445 # All shapes must be same 6446 def helper(shape, dtype=torch.float32): 6447 6448 x, cpu_x = None, None 6449 y, cpu_y = None, None 6450 z, cpu_z = None, None 6451 6452 if (dtype not in [torch.float32, torch.bool]): 6453 cpu_x = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False) 6454 x = cpu_x.detach().clone().to('mps') 6455 cpu_y = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False) 6456 y = cpu_y.detach().clone().to('mps') 6457 cpu_z = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False) 6458 z = cpu_z.detach().clone().to('mps') 6459 elif (dtype == torch.bool): 6460 cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False) 6461 x = cpu_x.detach().clone().to('mps') 6462 cpu_y = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False) 6463 y = cpu_y.detach().clone().to('mps') 6464 cpu_z = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False) 6465 z = cpu_z.detach().clone().to('mps') 6466 else: 6467 
cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True) 6468 x = cpu_x.detach().clone().to('mps').requires_grad_() 6469 cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True) 6470 y = cpu_y.detach().clone().to('mps').requires_grad_() 6471 cpu_z = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True) 6472 z = cpu_z.detach().clone().to('mps').requires_grad_() 6473 6474 stack = torch.stack([x, y, z], dim=1) 6475 stack_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1) 6476 6477 self.assertEqual(stack, stack_cpu) 6478 6479 helper([2, 8, 4, 5]) 6480 helper([2, 8, 4, 5], dtype=torch.float16) 6481 helper([2, 8, 4, 5], dtype=torch.int32) 6482 helper([2, 8, 4, 5], dtype=torch.int64) 6483 helper([2, 8, 4, 5], dtype=torch.bool) 6484 # Empty test - Currently failing! Empty tensor not handled! 6485 # helper([0, 2, 4, 5]) 6486 6487 # Test abs 6488 def test_abs(self): 6489 def helper(shape): 6490 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6491 x = cpu_x.detach().clone().to('mps') 6492 6493 abs_result = torch.abs(x) 6494 abs_result_cpu = torch.abs(cpu_x) 6495 6496 self.assertEqual(abs_result, abs_result_cpu) 6497 6498 helper((2, 8, 4, 5)) 6499 6500 def test_log(self): 6501 def helper(shape): 6502 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6503 x = cpu_x.detach().clone().to('mps') 6504 6505 log_result = torch.log(x) 6506 log_result_cpu = torch.log(cpu_x) 6507 6508 self.assertEqual(log_result, log_result_cpu) 6509 6510 helper((2, 8, 4, 5)) 6511 6512 def test_log_ten(self): 6513 def helper(shape): 6514 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6515 x = cpu_x.detach().clone().to('mps') 6516 6517 log_ten_result = torch.log10(x) 6518 log_ten_result_cpu = torch.log10(cpu_x) 6519 6520 self.assertEqual(log_ten_result, log_ten_result_cpu) 6521 6522 helper((2, 8, 4, 5)) 6523 6524 def test_log_two(self): 6525 def helper(shape): 6526 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6527 x = cpu_x.detach().clone().to('mps') 6528 6529 log_two_result = torch.log2(x) 6530 log_two_result_cpu = torch.log2(cpu_x) 6531 6532 self.assertEqual(log_two_result, log_two_result_cpu) 6533 6534 helper((2, 8, 4, 5)) 6535 6536 def test_log1p(self): 6537 def helper(shape): 6538 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6539 x = cpu_x.detach().clone().to('mps') 6540 6541 log_result = torch.log1p(x) 6542 log_result_cpu = torch.log1p(cpu_x) 6543 6544 self.assertEqual(log_result, log_result_cpu) 6545 6546 helper((2, 8, 4, 5)) 6547 6548 def test_logaddexp(self): 6549 def helper(shape): 6550 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6551 x = cpu_x.detach().clone().to('mps') 6552 6553 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6554 y = cpu_y.detach().clone().to('mps') 6555 6556 log_result = torch.logaddexp(x, y) 6557 log_result_cpu = torch.logaddexp(cpu_x, cpu_y) 6558 6559 self.assertEqual(log_result, log_result_cpu) 6560 6561 helper((2, 8, 4, 5)) 6562 6563 def test_logaddexp2(self): 6564 def helper(shape): 6565 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6566 x = cpu_x.detach().clone().to('mps') 6567 6568 cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6569 y = cpu_y.detach().clone().to('mps') 6570 6571 log_result = torch.logaddexp2(x, y) 6572 log_result_cpu = 
torch.logaddexp2(cpu_x, cpu_y) 6573 6574 self.assertEqual(log_result, log_result_cpu) 6575 6576 helper((2, 8, 4, 5)) 6577 6578 def test_logsumexp(self): 6579 def helper(shape): 6580 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6581 x = cpu_x.detach().clone().to('mps') 6582 6583 log_result = torch.logsumexp(x, -1) 6584 log_result_cpu = torch.logsumexp(cpu_x, -1) 6585 6586 self.assertEqual(log_result, log_result_cpu) 6587 6588 helper((2, 8, 4, 5)) 6589 6590 # Test concat forward 6591 def test_cat2(self): 6592 6593 def helper1(shape_x, shape_y, shape_z, shape_w): 6594 cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False) 6595 x = cpu_x.detach().clone().to('mps') 6596 6597 cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False) 6598 y = cpu_y.detach().clone().to('mps') 6599 6600 cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False) 6601 z = cpu_z.detach().clone().to('mps') 6602 6603 cpu_w = torch.randn(shape_w, device='cpu', dtype=torch.float, requires_grad=False) 6604 w = cpu_w.detach().clone().to('mps') 6605 6606 cat = torch.cat([x, y, z, w], dim=1) 6607 cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z, cpu_w], dim=1) 6608 6609 self.assertEqual(cat, cat_cpu) 6610 6611 def helper(shape_x, shape_y, shape_z): 6612 cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False) 6613 x = cpu_x.detach().clone().to('mps') 6614 6615 cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False) 6616 y = cpu_y.detach().clone().to('mps') 6617 6618 cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False) 6619 z = cpu_z.detach().clone().to('mps') 6620 6621 cat = torch.cat([x, y, z], dim=1) 6622 cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1) 6623 6624 self.assertEqual(cat, cat_cpu) 6625 6626 helper([2, 8, 4, 5], [2, 10, 4, 5], [2, 6, 4, 5]) 6627 helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5]) 6628 # Empty test - Currently failing! Empty tensor not handled! 
6629 # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5]) 6630 6631 # Test isnan 6632 def test_isnan(self): 6633 def helper(shape): 6634 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 6635 nan_index = [random.randrange(0, shape[0])] 6636 # set a randomly selected row to nan 6637 cpu_x.index_put_(indices=[torch.tensor(nan_index)], values=torch.tensor(float('nan'))) 6638 x = cpu_x.detach().clone().to('mps') 6639 6640 isnan_result = torch.isnan(x) 6641 isnan_result_cpu = torch.isnan(cpu_x) 6642 6643 self.assertEqual(isnan_result, isnan_result_cpu) 6644 6645 helper((8, 2, 4, 5)) 6646 6647 # Test reciprocal 6648 def test_reciprocal(self): 6649 def helper(shape): 6650 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 6651 x = cpu_x.detach().clone().to('mps').requires_grad_() 6652 6653 reciprocal_result = torch.reciprocal(x) 6654 reciprocal_result_cpu = torch.reciprocal(cpu_x) 6655 6656 cpu_grad = torch.ones_like(reciprocal_result_cpu) 6657 grad = cpu_grad.to('mps') 6658 6659 reciprocal_result.backward(gradient=grad) 6660 reciprocal_result_cpu.backward(gradient=cpu_grad) 6661 6662 self.assertEqual(reciprocal_result, reciprocal_result_cpu) 6663 self.assertEqual(x.grad, cpu_x.grad) 6664 6665 helper((2, 8, 4, 5)) 6666 6667 # Test sqrt 6668 def test_sqrt(self): 6669 def helper(shape): 6670 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 6671 x = cpu_x.detach().clone().to('mps').requires_grad_() 6672 6673 sqrt_result = torch.sqrt(x) 6674 sqrt_result_cpu = torch.sqrt(cpu_x) 6675 6676 cpu_grad = torch.ones_like(sqrt_result_cpu) 6677 grad = cpu_grad.to('mps') 6678 6679 sqrt_result.backward(gradient=grad) 6680 sqrt_result_cpu.backward(gradient=cpu_grad) 6681 6682 self.assertEqual(sqrt_result, sqrt_result_cpu) 6683 self.assertEqual(x.grad, cpu_x.grad) 6684 6685 helper((2, 8, 4, 5)) 6686 6687 # Test selu, elu, celu 6688 def test_elu(self): 6689 def helper(shape, alpha=1.0, memory_format=torch.contiguous_format): 6690 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float) 6691 cpu_x = cpu_x.to(memory_format=memory_format).requires_grad_() 6692 6693 x = cpu_x.detach().clone().to('mps').requires_grad_(True) 6694 for activation_func in [torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU()]: 6695 elu_result = activation_func(x) 6696 elu_result_cpu = activation_func(cpu_x) 6697 6698 cpu_grad = torch.randn(elu_result_cpu.shape) 6699 grad = cpu_grad.to('mps') 6700 6701 elu_result.backward(gradient=grad) 6702 elu_result_cpu.backward(gradient=cpu_grad) 6703 6704 self.assertEqual(elu_result, elu_result_cpu) 6705 self.assertEqual(x.grad, cpu_x.grad) 6706 6707 # Sweep alpha values and both memory formats 6708 for memory_format in [torch.channels_last, torch.contiguous_format]: 6709 for shape in [(2, 8, 4, 5)]: 6710 for alpha in [0.000001, 1.0, 2.3, 0.34, 23]: 6711 helper(shape, alpha, memory_format) 6712 6713 def test_elu_strided_output(self): 6714 # https://github.com/pytorch/pytorch/issues/124834 6715 elu_input = torch.randn(1, 1024, 500) 6716 alpha = float(1) 6717 inplace = False 6718 6719 elu_input_noncontiguous = elu_input.transpose(1, 2) 6720 self.assertEqual( 6721 F.elu(elu_input_noncontiguous.to('cpu'), alpha, inplace), 6722 F.elu(elu_input_noncontiguous.to('mps'), alpha, inplace) 6723 ) 6724 6725 # Test glu 6726 def test_glu(self): 6727 def helper(shape, dim=0): 6728 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 6729 x = cpu_x.detach().clone().to('mps').requires_grad_() 6730 6731 for 
activation_func in [torch.nn.GLU(dim=dim)]: 6732 glu_result = activation_func(x) 6733 glu_result_cpu = activation_func(cpu_x) 6734 6735 cpu_grad = torch.randn(glu_result_cpu.shape) 6736 grad = cpu_grad.to('mps') 6737 6738 glu_result.backward(gradient=grad) 6739 glu_result_cpu.backward(gradient=cpu_grad) 6740 6741 self.assertEqual(glu_result, glu_result_cpu) 6742 self.assertEqual(x.grad, cpu_x.grad) 6743 6744 for shape in [[4], (2, 4), (2, 8, 4, 6)]: 6745 for dim in range(len(shape)): 6746 helper(shape, dim) 6747 6748 # Test softplus 6749 def test_softplus(self): 6750 def helper(shape, beta, threshold, dtype): 6751 cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True) 6752 x = cpu_x.detach().clone().to('mps').requires_grad_() 6753 6754 softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x) 6755 softplus_result_cpu = torch.nn.Softplus(beta=beta, threshold=threshold)(cpu_x) 6756 6757 cpu_grad = torch.randn(softplus_result.shape) 6758 grad = cpu_grad.to('mps') 6759 6760 softplus_result.backward(gradient=grad) 6761 softplus_result_cpu.backward(gradient=cpu_grad) 6762 6763 self.assertEqual(softplus_result, softplus_result_cpu) 6764 self.assertEqual(x.grad, cpu_x.grad) 6765 6766 # Test empty shape too 6767 for shape, beta, threshold, dtype in product( 6768 [(), (2, 3), (10, 10), (2, 3, 4, 5)], 6769 [0.5, 1, 2, 3, 4], 6770 [0.5, 20, 30, 40, 50], 6771 [torch.float16, torch.float32] 6772 ): 6773 helper(shape, beta, threshold, dtype) 6774 6775 # Test silu 6776 6777 def test_silu(self): 6778 def helper(shape): 6779 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 6780 x = cpu_x.detach().clone().to('mps').requires_grad_() 6781 6782 silu_result = torch.nn.SiLU()(x) 6783 silu_result_cpu = torch.nn.SiLU()(cpu_x) 6784 6785 cpu_grad = torch.randn(silu_result_cpu.shape) 6786 grad = cpu_grad.to('mps') 6787 6788 silu_result.backward(gradient=grad) 6789 silu_result_cpu.backward(gradient=cpu_grad) 6790 6791 self.assertEqual(silu_result, silu_result_cpu) 6792 self.assertEqual(x.grad, cpu_x.grad) 6793 6794 # Test empty shape too 6795 for shape in [[], (2, 3), (2, 8, 4, 5)]: 6796 helper(shape) 6797 6798 def test_cast_mps_to_cpu(self): 6799 def helper(src_dtype, dst_dtype): 6800 input = torch.rand((1, 3, 128, 128), dtype=src_dtype) 6801 input_cast_mps = input.to('mps') 6802 input_cast_cpu = input_cast_mps.to('cpu', dtype=dst_dtype) 6803 6804 # needs to match the initial Tensor 6805 self.assertEqual(input_cast_cpu, input.to(dtype=dst_dtype)) 6806 helper(torch.half, torch.float) 6807 helper(torch.float, torch.half) 6808 6809 def test_cast_mps_to_mps(self): 6810 def helper(src_dtype, dst_dtype): 6811 input_cpu = torch.rand((1, 3, 128, 128), dtype=src_dtype) 6812 input_mps = input_cpu.to('mps') 6813 output_mps = input_mps.to(dtype=dst_dtype) 6814 output_cpu = input_cpu.to(dtype=dst_dtype) 6815 self.assertEqual(output_mps.cpu(), output_cpu) 6816 helper(torch.half, torch.float) 6817 helper(torch.float, torch.half) 6818 helper(torch.half, torch.long) 6819 helper(torch.float, torch.int) 6820 6821 def test_avg_pool2d_count_include_pad(self): 6822 cpu_x = torch.randn((1, 3, 9, 9), device='cpu', dtype=torch.float, requires_grad=True) 6823 x = cpu_x.detach().clone().to('mps').requires_grad_() 6824 pool = torch.nn.AvgPool2d(kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), ceil_mode=True, count_include_pad=True) 6825 ref_y = pool(cpu_x) 6826 y = pool(x) 6827 self.assertEqual(y, ref_y) 6828 cpu_grad = torch.randn(ref_y.shape) 6829 grad = cpu_grad.to('mps') 6830 
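# For reference, a hedged sketch (not part of the original test):
# count_include_pad=True makes the zero padding count toward the averaging
# denominator, so edge windows divide by the full kernel area. E.g. with a
# 2x2 kernel, stride 1, padding 1 on an all-ones 1x1x2x2 input:
#   >>> pool2 = torch.nn.AvgPool2d(2, stride=1, padding=1, count_include_pad=True)
#   >>> pool2(torch.ones(1, 1, 2, 2))[0, 0, 0, 0]
#   tensor(0.2500)  # one real cell plus three padded zeros, divided by 4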
ref_y.backward(gradient=cpu_grad) 6831 y.backward(gradient=grad) 6832 self.assertEqual(x.grad, cpu_x.grad) 6833 6834 # Test adaptive avg pool2d - when the input size is a multiple of output size 6835 # Not testing for channels last right now 6836 def test_adaptive_avg_pool2d_simple(self): 6837 def helper(input_shape, out_shape, channels_last): 6838 cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True) 6839 if (channels_last): 6840 cpu_x = cpu_x.to(memory_format=torch.channels_last) 6841 cpu_x.retain_grad() 6842 x = cpu_x.detach().clone().to('mps').requires_grad_() 6843 6844 avg_result = torch.nn.AdaptiveAvgPool2d(out_shape)(x) 6845 avg_result_cpu = torch.nn.AdaptiveAvgPool2d(out_shape)(cpu_x) 6846 6847 cpu_grad = torch.randn(avg_result_cpu.shape) 6848 grad = cpu_grad.to('mps') 6849 6850 avg_result.backward(gradient=grad) 6851 avg_result_cpu.backward(gradient=cpu_grad) 6852 6853 self.assertEqual(avg_result, avg_result_cpu) 6854 self.assertEqual(x.grad, cpu_x.grad) 6855 6856 helper((2, 2, 4, 4), (2, 2), False) 6857 helper((2, 2, 9, 9), (3, 3), False) 6858 helper((2, 2, 9, 9), (9, 9), False) 6859 helper((2, 2, 16, 16), (2, 2), False) 6860 helper((2, 2, 16, 16), (2, 16), False) 6861 6862 helper((2, 16, 16), (4, 4), False) 6863 6864 # Output shape larger than input shape 6865 6866 helper((2, 2, 4, 4), (8, 8), False) 6867 helper((2, 2, 2, 2), (4, 4), False) 6868 helper((2, 2, 3, 3), (9, 9), False) 6869 helper((2, 2, 2, 2), (16, 16), False) 6870 helper((2, 2, 2, 16), (16, 16), False) 6871 6872 helper((2, 4, 4), (16, 16), False) 6873 # an output size that does not evenly divide the input may be unsupported; tolerate a failure here 6874 try: 6875 helper((2, 2, 3, 3), (7, 7), False) 6876 except Exception: 6877 pass 6878 6879 # Test adaptive max pool2d - when the input size is a multiple of output size 6880 # Not testing for channels last right now 6881 def test_adaptive_max_pool2d_simple(self): 6882 def helper(input_shape, out_shape, return_indices, dtype, channels_last=False): 6883 cpu_x = None 6884 if (dtype in [torch.float16, torch.float32]): 6885 cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True) 6886 else: 6887 cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True) 6888 if (channels_last): 6889 cpu_x = cpu_x.to(memory_format=torch.channels_last) 6890 cpu_x.retain_grad() 6891 x = cpu_x.detach().clone().to('mps').requires_grad_() 6892 6893 max_result, max_indices = None, None 6894 max_result_cpu, max_indices_cpu = None, None 6895 6896 if (return_indices): 6897 max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x) 6898 max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x) 6899 else: 6900 max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x) 6901 max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x) 6902 6903 cpu_grad = torch.randn(max_result_cpu.shape) 6904 grad = cpu_grad.to('mps') 6905 6906 max_result.backward(gradient=grad) 6907 max_result_cpu.backward(gradient=cpu_grad) 6908 6909 self.assertEqual(max_result, max_result_cpu) 6910 if (return_indices): 6911 self.assertEqual(max_indices, max_indices_cpu) 6912 self.assertEqual(x.grad, cpu_x.grad) 6913 6914 for dtype in [torch.float32]: 6915 for return_indices in [False, True]: 6916 helper((2, 2, 4, 4), (2, 2), return_indices, dtype) 6917 helper((2, 2, 9, 9), (3, 3), return_indices, dtype) 6918 helper((2, 2, 9, 9), (9, 9), return_indices, dtype) 6919 helper((2, 2, 16, 16), (2, 2), return_indices, dtype) 6920 helper((2, 2, 16, 16), (2, 16), 
return_indices, dtype) 6921 helper((2, 16, 16), (4, 4), return_indices, dtype) 6922 6923 def test_gelu_simple(self): 6924 def helper(shape, dtype=torch.float, contiguous=True): 6925 cpu_x = torch.randn(shape, device='cpu', dtype=dtype) 6926 x = cpu_x.detach().clone().to('mps') 6927 6928 if not contiguous and (0 not in shape and len(shape) >= 2): 6929 # Transposing will make the tensor non-contiguous 6930 cpu_x = cpu_x.transpose(0, 1) 6931 x = x.transpose(0, 1) 6932 assert not x.is_contiguous() 6933 6934 cpu_x.requires_grad_() 6935 x.requires_grad_() 6936 6937 gelu_result = torch.nn.GELU()(x) 6938 # half-precision GELU has no CPU kernel, so compute the CPU reference in float 6939 gelu_result_cpu = torch.nn.GELU()(cpu_x.to(torch.float)) 6940 6941 cpu_grad = torch.ones_like(gelu_result_cpu) 6942 grad = cpu_grad.to('mps') 6943 6944 gelu_result.backward(gradient=grad) 6945 gelu_result_cpu.backward(gradient=cpu_grad) 6946 6947 atol = 1e-5 if dtype == torch.float else 1e-2 6948 rtol = 1e-3 if dtype == torch.float else 1e-2 6949 self.assertEqual(gelu_result, gelu_result_cpu.to(dtype), atol=atol, rtol=rtol) 6950 6951 assert x.grad is not None # Check that the grad is well-populated 6952 self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol) 6953 6954 # Test empty shape too 6955 for dtype in [torch.float, torch.half]: 6956 for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]: 6957 for contiguous in [True, False]: 6958 helper(shape, dtype, contiguous) 6959 # Test that gelu raises a RuntimeError for integral types 6960 for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: 6961 self.assertRaises(RuntimeError, lambda: torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps"))) 6962 6963 def test_mish_simple(self): 6964 def helper(shape, dtype=torch.float, contiguous=True): 6965 cpu_x = torch.randn(shape, device='cpu', dtype=dtype) 6966 x = cpu_x.detach().clone().to('mps') 6967 6968 if not contiguous and (0 not in shape and len(shape) >= 2): 6969 # Transposing will make the tensor non-contiguous 6970 cpu_x = cpu_x.transpose(0, 1) 6971 x = x.transpose(0, 1) 6972 assert not x.is_contiguous() 6973 6974 cpu_x.requires_grad_() 6975 x.requires_grad_() 6976 6977 mish_result = torch.nn.Mish()(x) 6978 mish_result_cpu = torch.nn.Mish()(cpu_x) 6979 6980 cpu_grad = torch.ones_like(mish_result_cpu) 6981 grad = cpu_grad.to('mps') 6982 6983 mish_result.backward(gradient=grad) 6984 mish_result_cpu.backward(gradient=cpu_grad) 6985 6986 atol = 1e-5 if dtype == torch.float else 1e-2 6987 rtol = 1e-3 if dtype == torch.float else 1e-2 6988 self.assertEqual(mish_result, mish_result_cpu.to(dtype), atol=atol, rtol=rtol) 6989 6990 assert x.grad is not None # Check that the grad is well-populated 6991 self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol) 6992 6993 # Test empty shape too 6994 for dtype in [torch.float, torch.half]: 6995 for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]: 6996 for contiguous in [True, False]: 6997 helper(shape, dtype, contiguous) 6998 6999 def test_gelu(self): 7000 def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None): 7001 numpy_dtype = { 7002 torch.bfloat16: torch.float, torch.float: torch.float, torch.double: torch.double 7003 }[dtype] 7004 devices = ['cpu'] 7005 devices += ['mps'] 7006 7007 def _gelu_ref(X): 7008 # erf-based numpy reference for the exact (non-approximate) GELU definition 7009 return X * 0.5 * (1.0 + np.vectorize(math.erf)(X / math.sqrt(2.0))) 7010 for d in devices: 7011 X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2] 7012 res = F.gelu(X) 7013 ref = torch.from_numpy(_gelu_ref(X.to(numpy_dtype).cpu().detach().numpy())) 7014 self.assertEqual(res, ref, rtol=rtol, 
atol=atol, exact_dtype=False) 7015 7016 for n in [1, 5, 10]: 7017 for m in [1, 5, 10]: 7018 _test_gelu(n, m, torch.float32, True) 7019 _test_gelu(n, m, torch.float32, False) 7020 7021 # Test multi threaded 7022 num_threads = torch.get_num_threads() 7023 torch.set_num_threads(4) 7024 try: 7025 _test_gelu(32, 32, torch.float32, False) 7026 finally: 7027 torch.set_num_threads(num_threads) 7028 7029 def test_gelu_tanh(self): 7030 def helper(shape): 7031 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float) 7032 x = cpu_x.detach().clone().to('mps') 7033 7034 gelu_tanh_result = torch.nn.functional.gelu(x, approximate='tanh') 7035 gelu_tanh_result_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh') 7036 self.assertEqual(gelu_tanh_result, gelu_tanh_result_cpu) 7037 7038 helper((2, 8, 4, 5)) 7039 7040 # Test hardtanh 7041 def test_hardtanh(self): 7042 def helper(shape, min_val, max_val, inplace=False): 7043 cpu_x = None 7044 x = None 7045 7046 if (not inplace): 7047 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 7048 x = cpu_x.detach().clone().to('mps').requires_grad_() 7049 else: 7050 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 7051 x = cpu_x.detach().clone().to('mps') 7052 7053 hardtanh_result = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(x) 7054 hardtanh_result_cpu = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(cpu_x) 7055 7056 self.assertEqual(hardtanh_result, hardtanh_result_cpu) 7057 7058 if (not inplace): 7059 cpu_grad = torch.randn(hardtanh_result_cpu.shape) 7060 grad = cpu_grad.to('mps') 7061 hardtanh_result.backward(gradient=grad) 7062 hardtanh_result_cpu.backward(gradient=cpu_grad) 7063 self.assertEqual(x.grad, cpu_x.grad) 7064 7065 # Test empty shape too 7066 for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]: 7067 for min_val, max_val in zip([-1, -2, 3], [1, -1, 4]): 7068 helper(shape, min_val, max_val) 7069 helper(shape, min_val, max_val, inplace=True) 7070 7071 def test_hardswish(self): 7072 def helper(shape, inplace=False, requires_grad=True): 7073 m = nn.Hardswish(inplace=inplace) 7074 7075 input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=requires_grad) 7076 input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad) 7077 7078 if inplace and requires_grad: # check that both raise runtime error 7079 self.assertRaises(RuntimeError, lambda: m(input_cpu)) 7080 self.assertRaises(RuntimeError, lambda: m(input_mps)) 7081 return 7082 7083 output_cpu = m(input_cpu) 7084 output_mps = m(input_mps) 7085 7086 cpu_grad = torch.ones_like(output_cpu) 7087 mps_grad = cpu_grad.to('mps') 7088 7089 self.assertEqual(output_cpu, output_mps) 7090 7091 if requires_grad: 7092 output_cpu.backward(gradient=cpu_grad) 7093 output_mps.backward(gradient=mps_grad) 7094 7095 self.assertEqual(input_cpu.grad, input_mps.grad) 7096 7097 for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]: 7098 helper(shape, inplace=False, requires_grad=False) 7099 helper(shape, inplace=True, requires_grad=False) 7100 helper(shape, inplace=False, requires_grad=True) 7101 helper(shape, inplace=True, requires_grad=True) 7102 7103 def test_transpose_2D(self): 7104 values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] 7105 values1 = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] 7106 cpu_x = torch.tensor(values, device='cpu') 7107 mps_x = torch.tensor(values, device='mps') 7108 mps_x1 = torch.tensor(values1, device='mps') 7109 7110 cpu_transpose = 
torch.transpose(cpu_x, 0, 1) 7111 mps_transpose = torch.transpose(mps_x, 0, 1) 7112 self.assertEqual(cpu_transpose, mps_transpose.to('cpu')) 7113 7114 def test_transpose_3D(self): 7115 values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]] 7116 cpu_x = torch.tensor(values, device='cpu') 7117 mps_x = torch.tensor(values, device='mps') 7118 7119 cpu_transpose1 = torch.transpose(cpu_x, 0, 1) 7120 mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu') 7121 self.assertEqual(cpu_transpose1, mps_transpose1) 7122 7123 cpu_transpose2 = torch.transpose(cpu_x, 0, 2) 7124 mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu') 7125 self.assertEqual(cpu_transpose2, mps_transpose2) 7126 7127 cpu_transpose3 = torch.transpose(cpu_x, 1, 2) 7128 mps_transpose3 = torch.transpose(mps_x, 1, 2).to('cpu') 7129 self.assertEqual(cpu_transpose3, mps_transpose3) 7130 7131 7132 def test_transpose_4D(self): 7133 values = [[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]], 7134 [[[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]], [[19.0, 20.0, 21.0], [22.0, 23.0, 24.0]]]] 7135 cpu_x = torch.tensor(values, device='cpu') 7136 mps_x = torch.tensor(values, device='mps') 7137 7138 cpu_transpose1 = torch.transpose(cpu_x, 0, 1) 7139 mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu') 7140 self.assertEqual(cpu_transpose1, mps_transpose1) 7141 7142 cpu_transpose2 = torch.transpose(cpu_x, 0, 2) 7143 mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu') 7144 self.assertEqual(cpu_transpose2, mps_transpose2) 7145 7146 cpu_transpose3 = torch.transpose(cpu_x, 0, 3) 7147 mps_transpose3 = torch.transpose(mps_x, 0, 3).to('cpu') 7148 self.assertEqual(cpu_transpose3, mps_transpose3) 7149 7150 cpu_transpose4 = torch.transpose(cpu_x, 3, 1) 7151 mps_transpose4 = torch.transpose(mps_x, 3, 1).to('cpu') 7152 self.assertEqual(cpu_transpose4, mps_transpose4) 7153 7154 cpu_transpose5 = torch.transpose(cpu_x, 3, 2) 7155 mps_transpose5 = torch.transpose(mps_x, 3, 2).to('cpu') 7156 self.assertEqual(cpu_transpose5, mps_transpose5) 7157 7158 cpu_transpose6 = torch.transpose(cpu_x, 1, 2) 7159 mps_transpose6 = torch.transpose(mps_x, 1, 2).to('cpu') 7160 self.assertEqual(cpu_transpose6, mps_transpose6) 7161 7162 # Test sign 7163 def test_sign(self): 7164 def helper(shape): 7165 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 7166 x = cpu_x.detach().clone().to('mps').requires_grad_() 7167 7168 sign_result = torch.sign(x) 7169 sign_result_cpu = torch.sign(cpu_x) 7170 7171 cpu_grad = torch.ones_like(sign_result_cpu) 7172 grad = cpu_grad.to('mps') 7173 7174 sign_result.backward(gradient=grad) 7175 sign_result_cpu.backward(gradient=cpu_grad) 7176 7177 self.assertEqual(sign_result, sign_result_cpu) 7178 7179 helper((2, 8, 4, 5)) 7180 7181 def test_signbit(self): 7182 def helper(shape, dtype): 7183 cpu_x = torch.randn(shape, device='cpu').to(dtype) 7184 x = cpu_x.clone().to('mps') 7185 7186 signbit_result = torch.signbit(x) 7187 signbit_result_cpu = torch.signbit(cpu_x) 7188 7189 self.assertEqual(signbit_result, signbit_result_cpu) 7190 7191 helper((2, 8, 4, 5), torch.int) 7192 helper((2, 8, 4, 5), torch.float) 7193 helper((2, 8, 4, 5), torch.int64) 7194 7195 # Test neg 7196 def test_neg(self): 7197 def helper(shape): 7198 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 7199 x = cpu_x.detach().clone().to('mps').requires_grad_() 7200 7201 neg_result = torch.neg(x) 7202 neg_result_cpu = torch.neg(cpu_x) 7203 7204 cpu_grad = 
torch.ones_like(neg_result_cpu) 7205 grad = cpu_grad.to('mps') 7206 7207 neg_result.backward(gradient=grad) 7208 neg_result_cpu.backward(gradient=cpu_grad) 7209 7210 self.assertEqual(neg_result, neg_result_cpu) 7211 7212 helper((2, 8, 4, 5)) 7213 7214 def test_neg_strided_input(self): 7215 # See https://github.com/pytorch/pytorch/issues/98074#issuecomment-1496088337 7216 x = torch.arange(18.0, device='mps').reshape(2, 3, 3) 7217 y = x.permute(1, 0, 2)[..., 1] 7218 z = y + y.neg() 7219 self.assertEqual(z.abs().max().item(), 0.0) 7220 7221 # Test index add 7222 def test_index_add(self): 7223 def helper(shape, dim, index, source_shape, alpha, x_dtype=torch.float32, idx_dtype=torch.int32): 7224 cpu_x = torch.randn(shape, device='cpu', dtype=x_dtype, requires_grad=False) 7225 x = cpu_x.detach().clone().to('mps') 7226 7227 cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype) 7228 idx = cpu_idx.detach().clone().to('mps') 7229 7230 cpu_source = torch.randn(source_shape, device='cpu', dtype=x_dtype, requires_grad=False) 7231 source = cpu_source.detach().clone().to('mps') 7232 7233 idx_result = torch.index_add(x, dim=dim, index=idx, source=source, alpha=alpha) 7234 idx_result_cpu = torch.index_add(cpu_x, dim=dim, index=cpu_idx, source=cpu_source, alpha=alpha) 7235 self.assertEqual(idx_result, idx_result_cpu) 7236 7237 helper((2, 8, 4, 5), 0, [0, 1, 0], (3, 8, 4, 5), 5) 7238 helper((8, 8, 4, 5), 0, [7], (1, 8, 4, 5), 6.0) 7239 helper((2, 8, 4, 5), 1, [0, 3, 7], (2, 3, 4, 5), 5) 7240 helper((2, 8, 4, 5), 2, [3, 0], (2, 8, 2, 5), 3.0) 7241 helper((2, 8, 4, 5), 3, [2, 3, 0], (2, 8, 4, 3), 4) 7242 helper((2, 3, 3), -1, [1, 2], (2, 3, 2), 6.0) 7243 # test result dim=1 7244 helper((2,), 0, [1], (1,), 6.0) 7245 helper(2, 0, 1, 1, 6) 7246 # test float16 7247 helper((2,), 0, [1], (1,), 6.0, x_dtype=torch.float16) 7248 7249 def test_index_64bit(self): 7250 """ Test that index operations work for 4Gb+ tensors """ 7251 if product_version < 14.0: 7252 raise unittest.SkipTest("Sonoma is needed for large tensors, see https://github.com/pytorch/pytorch/issues/84039") 7253 # Cleanup memory 7254 gc.collect() 7255 torch.mps.empty_cache() 7256 # Check that index operations work for 4+GB tensors 7257 x = torch.rand(16000, 67120, device="mps") 7258 self.assertGreater(x.element_size() * x.numel(), 2**32) 7259 idx = torch.arange(0, 2, device="mps") 7260 x_sampled = x[:, idx] 7261 self.assertEqual(x[:, 0], x_sampled[:, 0]) 7262 # Reclaim memory after running the tests 7263 del x 7264 gc.collect() 7265 torch.mps.empty_cache() 7266 7267 def test_mm_large(self): 7268 """ Test that MM works for matrices with index larger than 32K """ 7269 x = torch.rand(10, 1, device="mps") 7270 y = torch.rand(1, 32769, device="mps") 7271 # This used to crash with: 7272 # error: subRange.start (24576) is not less than length of dimension[0] (16384) 7273 # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095 7274 self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0) 7275 7276 def compare_mm(m, n, k, dtype=torch.float): 7277 x = torch.rand(m, n, device="mps", dtype=dtype) 7278 y = torch.rand(n, k, device="mps", dtype=dtype) 7279 z = torch.mm(x, y).cpu() 7280 z_cpu = torch.mm(x.cpu(), y.cpu()) 7281 self.assertEqual(z, z_cpu) 7282 7283 # Used to produce incorrect results with MPS on M1 running MacOS 14.3, but correct with Metal 7284 compare_mm(1024, 1, 32769) 7285 # one more time, but with dimensions inverted 7286 # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984 7287 
compare_mm(32769, 1, 1025) 7288 7289 if product_version >= 14.0: 7290 # Test bfloat16 mm 7291 compare_mm(1024, 1, 32769, torch.bfloat16) 7292 7293 @unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test") 7294 @unittest.skipIf(product_version < 14.0, "Can't allocate 4Gb tensor on MacOS 13") 7295 def test_copy_large(self): 7296 """ Test that copy of 4Gb+ tensors works """ 7297 x = torch.ones((2**30 + 11,), dtype=torch.float32) 7298 y = x.to(device="mps") 7299 self.assertTrue(torch.all(y == torch.tensor(1.0, device="mps"))) 7300 del y 7301 del x 7302 7303 # Test flip 7304 def test_flip(self): 7305 def helper(shape, dims): 7306 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 7307 x = cpu_x.detach().clone().to('mps') 7308 7309 flip_result = torch.flip(x, dims=dims) 7310 flip_result_cpu = torch.flip(cpu_x, dims=dims) 7311 7312 self.assertEqual(flip_result, flip_result_cpu) 7313 7314 helper((2, 8, 4, 5), [0]) 7315 helper((8, 8, 4, 5), [0, 1]) 7316 helper((2, 8, 4, 5), (0, 1, 2, 3)) 7317 helper((2, 3, 3), (-1,)) 7318 # empty dims 7319 helper((2, 8, 4, 5), []) 7320 # input.numel() == 1 7321 helper((1,), (0,)) 7322 # input.numel() == 0 7323 helper((0,), (0,)) 7324 # the flipped dim has size 1, so the flip is a no-op 7325 helper((1, 3), [0]) 7326 7327 # Test index select 7328 def test_index_select(self): 7329 def helper(shape, dim, index, idx_dtype=torch.int32): 7330 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False) 7331 x = cpu_x.detach().clone().to('mps') 7332 7333 cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype) 7334 idx = cpu_idx.detach().clone().to('mps') 7335 7336 idx_result = torch.index_select(x, dim=dim, index=idx) 7337 idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx) 7338 7339 self.assertEqual(idx_result, idx_result_cpu) 7340 7341 helper((2, 8, 4, 5), 0, [1]) 7342 helper((8, 8, 4, 5), 0, [0, 3, 2, 7, 6]) 7343 helper((2, 8, 4, 5), 1, [0, 3, 2, 7, 6]) 7344 helper((2, 8, 4, 5), 2, [3, 0, 1]) 7345 helper((2, 8, 4, 5), 3, [2, 3, 0]) 7346 helper((2, 3, 3), -1, [1, 2]) 7347 helper((), 0, [0]) 7348 helper((5), 0, []) 7349 7350 def test_index_select_scalar(self): 7351 def helper(value, dim, index, idx_dtype=torch.int32): 7352 cpu_x = torch.tensor(value, device='cpu', dtype=torch.float, requires_grad=False) 7353 x = cpu_x.detach().clone().to('mps') 7354 7355 cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype) 7356 idx = cpu_idx.detach().clone().to('mps') 7357 7358 idx_result = torch.index_select(x, dim=dim, index=idx) 7359 idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx) 7360 7361 self.assertEqual(idx_result, idx_result_cpu) 7362 7363 helper(22, 0, [0]) 7364 with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"): 7365 helper(22, 0, []) 7366 7367 def test_embedding_dense_backward(self): 7368 def helper(n, d, m, idx): 7369 embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps') 7370 embedding_weight = embeddingMPS.weight.detach().cpu() 7371 W_MPS = torch.randn((m, d), requires_grad=True, device='mps') 7372 idx_MPS = torch.tensor(idx, device='mps') 7373 a_MPS = embeddingMPS.weight.clone() @ W_MPS.t() # weight must be cloned for this to be differentiable 7374 a_MPS.retain_grad() 7375 b_MPS = embeddingMPS(idx_MPS) @ W_MPS.t() # modifies weight in-place 7376 b_MPS.retain_grad() 7377 out_MPS = (a_MPS.unsqueeze(0) + b_MPS) 7378 loss_MPS = out_MPS.sigmoid().prod() 7379 loss_MPS.backward() 7380 7381 embeddingCPU = nn.Embedding(n, d, 
max_norm=True, _weight=embedding_weight) 7382 W_CPU = W_MPS.to('cpu') 7383 idx_CPU = torch.tensor(idx) 7384 a_CPU = embeddingCPU.weight.clone() @ W_CPU.t() # weight must be cloned for this to be differentiable 7385 a_CPU.retain_grad() 7386 b_CPU = embeddingCPU(idx_CPU) @ W_CPU.t() # modifies weight in-place 7387 b_CPU.retain_grad() 7388 out_CPU = (a_CPU.unsqueeze(0) + b_CPU) 7389 loss_CPU = out_CPU.sigmoid().prod() 7390 loss_CPU.backward() 7391 7392 self.assertEqual(b_CPU.grad, b_MPS.grad) 7393 self.assertEqual(a_CPU.grad, a_MPS.grad) 7394 7395 helper(3, 5, 7, [0, 1, 2]) 7396 helper(3, 6, 7, [0, 1, 2]) # verify that a change in shape does not cause cached graph lookup problems 7397 helper(3, 5, 7, 2) # test scalar index 7398 7399 # Test pytorch gather 7400 def test_gather(self): 7401 def helper(shape, dim, idx_shape, idx_dtype=torch.int64): 7402 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 7403 x = cpu_x.detach().clone().to('mps').requires_grad_() 7404 7405 # Indices should be taken from the range of the axis along which the gather is done 7406 idx_np = np.random.randint(0, shape[dim], idx_shape) 7407 7408 cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype) 7409 idx = cpu_idx.detach().clone().to('mps') 7410 7411 gather_result = torch.gather(x, dim=dim, index=idx) 7412 gather_result_cpu = torch.gather(cpu_x, dim=dim, index=cpu_idx) 7413 7414 cpu_grad = torch.randn(idx_shape, device='cpu', dtype=torch.float) 7415 grad = cpu_grad.to('mps') 7416 gather_result.backward(gradient=grad) 7417 gather_result_cpu.backward(gradient=cpu_grad) 7418 7419 self.assertEqual(gather_result, gather_result_cpu) 7420 self.assertEqual(cpu_x.grad, x.grad) 7421 7422 helper((6, 3, 3), 0, (3, 3, 3)) 7423 helper((2, 3, 3, 3), 0, (10, 3, 3, 3)) 7424 helper((2, 8, 4, 5), 0, (10, 8, 4, 5)) 7425 helper((2, 8, 4, 5), 0, (10, 6, 3, 2)) 7426 helper((8, 8, 4, 5), 0, (6, 8, 4, 5)) 7427 helper((8, 8, 4, 5), 0, (6, 7, 2, 3)) 7428 helper((2, 8, 4, 5), 1, (2, 5, 3, 4)) 7429 helper((2, 8, 4, 5), 2, (1, 8, 10, 3)) 7430 helper((2, 8, 4, 5), 3, (2, 5, 3, 12)) 7431 7432 # Test pytorch gather with a scalar input 7433 def test_gather_scalar(self): 7434 idx_dtype = torch.int64 7435 cpu_x = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True) 7436 x = cpu_x.detach().clone().to('mps').requires_grad_() 7437 7438 idx_np = [0] 7439 7440 cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype) 7441 idx = cpu_idx.detach().clone().to('mps') 7442 7443 gather_result = torch.gather(x, dim=0, index=idx) 7444 gather_result_cpu = torch.gather(cpu_x, dim=0, index=cpu_idx) 7445 7446 cpu_grad = torch.randn([1], device='cpu', dtype=torch.float) 7447 grad = cpu_grad.to('mps') 7448 gather_result.backward(gradient=grad) 7449 gather_result_cpu.backward(gradient=cpu_grad) 7450 7451 self.assertEqual(gather_result, gather_result_cpu) 7452 self.assertEqual(cpu_x.grad, x.grad) 7453 7454 # Test pytorch scatter_add and scatter 7455 def test_scatter_add(self): 7456 def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True): 7457 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 7458 x = cpu_x.detach().clone().to('mps').requires_grad_() 7459 7460 cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True) 7461 src = cpu_src.detach().clone().to('mps').requires_grad_() 7462 7463 # Indices should be taken from the range of the self axis along which the scatter is done 7464 idx_np = None 7465 if (do_add): 7466 idx_np = np.random.randint(0, shape[dim], idx_shape) 7467 else: 7468 
idx_np = np.array([[0, 1, 2], 7469 [1, 2, 3], 7470 [2, 3, 4], 7471 [3, 4, 5], 7472 [4, 5, 6]]) 7473 7474 cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype) 7475 idx = cpu_idx.detach().clone().to('mps') 7476 7477 scatter_result = None 7478 scatter_result_cpu = None 7479 7480 if (do_add): 7481 scatter_result = torch.scatter_add(x, dim=dim, index=idx, src=src) 7482 scatter_result_cpu = torch.scatter_add(cpu_x, dim=dim, index=cpu_idx, src=cpu_src) 7483 else: 7484 scatter_result = torch.scatter(x, dim=dim, index=idx, src=src) 7485 scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src) 7486 7487 cpu_grad = None 7488 grad = None 7489 7490 if (idx_shape == src_shape): 7491 cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float) 7492 grad = cpu_grad.to('mps') 7493 scatter_result.backward(gradient=grad) 7494 scatter_result_cpu.backward(gradient=cpu_grad) 7495 7496 self.assertEqual(scatter_result, scatter_result_cpu) 7497 if (idx_shape == src_shape): 7498 self.assertEqual(cpu_x.grad, x.grad) 7499 self.assertEqual(cpu_src.grad, src.grad) 7500 7501 helper((2, 3), 0, (5, 3), (5, 3)) 7502 helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)) 7503 helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5)) 7504 helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2)) 7505 helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2)) 7506 helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5)) 7507 7508 helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5)) 7509 helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2)) 7510 helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3)) 7511 helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3)) 7512 7513 helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8)) 7514 helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6)) 7515 helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6)) 7516 7517 # Test scatter src 7518 helper((8, 3), 0, (5, 3), (5, 3), do_add=False) 7519 helper((10, 3), 0, (5, 3), (5, 8), do_add=False) 7520 7521 # Test pytorch scatter_add and scatter for scalar input 7522 def test_scatter_add_scalar(self): 7523 def helper(idx_dtype=torch.int64, do_add=True): 7524 cpu_x = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True) 7525 x = cpu_x.detach().clone().to('mps').requires_grad_() 7526 7527 cpu_src = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True) 7528 src = cpu_src.detach().clone().to('mps').requires_grad_() 7529 7530 # Indices should be taken from range of axis along which gathering is done 7531 idx_np = [0] 7532 7533 cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype) 7534 idx = cpu_idx.detach().clone().to('mps') 7535 7536 scatter_result = None 7537 scatter_result_cpu = None 7538 7539 if (do_add): 7540 scatter_result = torch.scatter_add(x, dim=0, index=idx, src=src) 7541 scatter_result_cpu = torch.scatter_add(cpu_x, dim=0, index=cpu_idx, src=cpu_src) 7542 else: 7543 scatter_result = torch.scatter(x, dim=0, index=idx, src=src) 7544 scatter_result_cpu = torch.scatter(cpu_x, dim=0, index=cpu_idx, src=cpu_src) 7545 7546 cpu_grad = None 7547 grad = None 7548 7549 cpu_grad = torch.tensor(1.2, device='cpu', dtype=torch.float) 7550 grad = cpu_grad.to('mps') 7551 scatter_result.backward(gradient=grad) 7552 scatter_result_cpu.backward(gradient=cpu_grad) 7553 7554 self.assertEqual(scatter_result, scatter_result_cpu) 7555 self.assertEqual(cpu_x.grad, x.grad) 7556 self.assertEqual(cpu_src.grad, src.grad) 7557 7558 helper() 7559 helper(do_add=False) 7560 7561 # Test pytorch scatter_reduce 7562 def 
test_scatter_reduce(self): 7563 def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"): 7564 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 7565 x = cpu_x.detach().clone().to('mps').requires_grad_() 7566 7567 cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True) 7568 src = cpu_src.detach().clone().to('mps').requires_grad_() 7569 7570 # Indices should be taken from range of axis along which gathering is done 7571 idx_np = np.random.randint(0, shape[dim], idx_shape) 7572 7573 cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype) 7574 idx = cpu_idx.detach().clone().to('mps') 7575 7576 scatter_result = torch.scatter(x, dim=dim, index=idx, src=src, reduce=reduce_str) 7577 scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src, reduce=reduce_str) 7578 7579 self.assertEqual(scatter_result, scatter_result_cpu) 7580 7581 # for reduce in ["sum", "prod", "amax", "amin"]: 7582 for reduce_type in ["add", "multiply"]: 7583 helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce_type) 7584 helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type) 7585 helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type) 7586 helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type) 7587 helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type) 7588 helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce_type) 7589 7590 helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce_type) 7591 helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce_type) 7592 helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce_type) 7593 helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce_type) 7594 7595 helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce_type) 7596 helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce_type) 7597 helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce_type) 7598 7599 def test_is_nonzero(self): 7600 self.assertFalse(torch.is_nonzero(torch.tensor([0.]).to('mps'))) 7601 self.assertTrue(torch.is_nonzero(torch.tensor([1.5]).to('mps'))) 7602 self.assertFalse(torch.is_nonzero(torch.tensor([False]).to('mps'))) 7603 self.assertTrue(torch.is_nonzero(torch.tensor([3]).to('mps'))) 7604 7605 # Test triu 7606 def test_triu(self): 7607 def helper(shape, diag=0): 7608 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 7609 x = cpu_x.detach().clone().to('mps').requires_grad_() 7610 7611 triu_result = torch.triu(x, diag) 7612 triu_result_cpu = torch.triu(cpu_x, diag) 7613 7614 cpu_grad = torch.randn(triu_result_cpu.shape) 7615 grad = cpu_grad.to('mps') 7616 7617 triu_result.backward(gradient=grad) 7618 triu_result_cpu.backward(gradient=cpu_grad) 7619 7620 self.assertEqual(triu_result, triu_result_cpu) 7621 self.assertEqual(x.grad, cpu_x.grad) 7622 7623 helper((2, 8, 4, 5)) 7624 helper((2, 8, 4, 5), diag=1) 7625 helper((2, 8, 4, 5), diag=2) 7626 helper((2, 8, 4, 5), diag=3) 7627 helper((2, 8, 4, 5), diag=-1) 7628 helper((2, 8, 4, 5), diag=-2) 7629 helper((2, 8, 4, 5), diag=-3) 7630 7631 # Test inverse 7632 def test_inverse(self): 7633 def helper(n): 7634 cpu_input = torch.randn(n, n, device='cpu') 7635 mps_input = cpu_input.to('mps') 7636 7637 cpu_result = torch.linalg.inv(cpu_input) 7638 mps_result = torch.linalg.inv(mps_input) 7639 self.assertEqual(cpu_result, 
mps_result) 7640 7641 helper(2) 7642 helper(6) 7643 helper(3) 7644 helper(8) 7645 7646 # Test tril 7647 def test_tril(self): 7648 def helper(shape, diag=0): 7649 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 7650 x = cpu_x.detach().clone().to('mps').requires_grad_() 7651 7652 tril_result = torch.tril(x, diag) 7653 tril_result_cpu = torch.tril(cpu_x, diag) 7654 7655 cpu_grad = torch.randn(tril_result_cpu.shape) 7656 grad = cpu_grad.to('mps') 7657 7658 tril_result.backward(gradient=grad) 7659 tril_result_cpu.backward(gradient=cpu_grad) 7660 7661 self.assertEqual(tril_result, tril_result_cpu) 7662 self.assertEqual(x.grad, cpu_x.grad) 7663 7664 helper((2, 8, 4, 5)) 7665 helper((2, 8, 4, 5), diag=1) 7666 helper((2, 8, 4, 5), diag=2) 7667 helper((2, 8, 4, 5), diag=3) 7668 helper((2, 8, 4, 5), diag=-1) 7669 helper((2, 8, 4, 5), diag=-2) 7670 helper((2, 8, 4, 5), diag=-3) 7671 7672 # Test eye 7673 def test_eye(self): 7674 def helper(n, m, dtype): 7675 cpu_result = None 7676 result = None 7677 7678 if (n == m): 7679 cpu_result = torch.eye(n, dtype=dtype, device='cpu') 7680 result = torch.eye(n, dtype=dtype, device='mps') 7681 else: 7682 cpu_result = torch.eye(n, m, dtype=dtype, device='cpu') 7683 result = torch.eye(n, m, dtype=dtype, device='mps') 7684 7685 self.assertEqual(result, cpu_result) 7686 7687 for dtype in [torch.bool, torch.float16, torch.float32, torch.uint8, torch.int16, torch.int32, torch.int64]: 7688 helper(2, 2, dtype) 7689 helper(2, 3, dtype) 7690 helper(0, 2, dtype) 7691 helper(0, 0, dtype) 7692 helper(3, 8, dtype) 7693 helper(8, 3, dtype) 7694 7695 # Test diag 7696 def test_diag(self): 7697 def helper(shape, diag=0): 7698 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 7699 x = cpu_x.detach().clone().to('mps').requires_grad_() 7700 7701 diag_result = torch.diag(x, diag) 7702 diag_result_cpu = torch.diag(cpu_x, diag) 7703 7704 # cpu_grad = torch.randn(diag_result_cpu.shape) 7705 # grad = cpu_grad.to('mps') 7706 7707 # diag_result.backward(gradient=grad) 7708 # diag_result_cpu.backward(gradient=cpu_grad) 7709 7710 self.assertEqual(diag_result, diag_result_cpu) 7711 # self.assertEqual(x.grad, cpu_x.grad) 7712 7713 for shape in [(5, 5), (5, 6), (6, 5), (5,), (6,)]: 7714 for diag in [0, 1, 2, 3, 4, -1, -2, -3, -4]: 7715 helper(shape, diag=diag) 7716 7717 # Test linspace 7718 def test_linspace(self): 7719 def helper(start, end, steps, dtype=torch.float32): 7720 cpu_result = torch.tensor(np.linspace(start, end, steps), dtype=dtype) 7721 result = torch.linspace(start, end, steps, dtype=dtype, device='mps') 7722 self.assertEqual(cpu_result, result) 7723 7724 for dtype in [torch.float32, torch.int32, torch.uint8, torch.int64]: 7725 helper(2, 5, 10, dtype) 7726 helper(2, 2, 10, dtype) 7727 helper(5, 2, 10, dtype) 7728 helper(2, 2, 0, dtype) 7729 7730 # Test arange 7731 def test_arange(self): 7732 self.assertEqual(np.arange(10), torch.arange(10, device='mps')) 7733 self.assertEqual(np.arange(7, 1, -1), torch.arange(7, 1, -1, device='mps')) 7734 self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), torch.arange(1, 2, .3, device='mps')) 7735 self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(6.3, device='mps')) 7736 7737 def test_arange_empty(self): 7738 out_mps = torch.tensor([], device="mps") 7739 out_cpu = torch.tensor([], device="cpu") 7740 7741 y_mps = torch.arange(0, 0, 1, out=out_mps) 7742 y_cpu = torch.arange(0, 0, 1, out=out_cpu) 7743 self.assertEqual(y_mps, y_cpu) 7744 7745 # Test range 7746 def test_range(self): 7747 
self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps')) 7748 self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps')) 7749 self.assertEqual(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps')) 7750 self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(0, 6.3, device='mps')) 7751 7752 # Test softmax 7753 def test_softmax(self): 7754 def helper(shape, dim, channels_last=False): 7755 cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 7756 if (channels_last): 7757 cpu_x = cpu_x.to(memory_format=torch.channels_last) 7758 cpu_x.retain_grad() 7759 x = cpu_x.detach().clone().to('mps').requires_grad_() 7760 7761 softmax_result = torch.nn.functional.softmax(x, dim=dim) 7762 softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim) 7763 7764 # Currently NOT testing the backward pass for channels_last input 7765 cpu_grad = None 7766 grad = None 7767 7768 if (not channels_last): 7769 cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float) 7770 grad = cpu_grad.to('mps') 7771 7772 softmax_result.backward(gradient=grad) 7773 softmax_result_cpu.backward(gradient=cpu_grad) 7774 7775 self.assertEqual(softmax_result, softmax_result_cpu) 7776 if (not channels_last): 7777 self.assertEqual(x.grad, cpu_x.grad) 7778 7779 def helper2(dim): 7780 cpu_x = torch.tensor(1.23, device='cpu', dtype=torch.float, requires_grad=True) 7781 x = cpu_x.detach().clone().to('mps').requires_grad_() 7782 7783 softmax_result = torch.nn.functional.softmax(x, dim=dim) 7784 softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim) 7785 7786 cpu_grad = torch.tensor(2.34, device='cpu', dtype=torch.float) 7787 grad = cpu_grad.to('mps') 7788 7789 softmax_result.backward(gradient=grad) 7790 softmax_result_cpu.backward(gradient=cpu_grad) 7791 7792 self.assertEqual(softmax_result, softmax_result_cpu) 7793 self.assertEqual(x.grad, cpu_x.grad) 7794 7795 helper2(0) 7796 7797 for channels_last in [False]: 7798 for shape in [(2, 4, 8, 5), (3, 4, 6, 7, 2)]: 7799 if (len(shape) != 4 and channels_last): 7800 continue 7801 for dim in [0, 1, 2, 3, -1, -2, -3]: 7802 helper(shape, dim, channels_last) 7803 7804 def test_nan_to_num(self): 7805 inputCPU = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14]) 7806 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_() 7807 outputCPU = torch.nan_to_num(inputCPU, nan=2.0, posinf=1.0, neginf=-1.0) 7808 outputMPS = torch.nan_to_num(inputMPS, nan=2.0, posinf=1.0, neginf=-1.0) 7809 self.assertEqual(outputMPS, outputCPU) 7810 7811 # Test where 7812 def test_where(self): 7813 def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float): 7814 7815 cpu_cond = torch.randint(2, shape, device='cpu', dtype=cond_dtype, requires_grad=False) 7816 cond = cpu_cond.detach().clone().to('mps') 7817 7818 cpu_x = torch.randn(x_shape, device='cpu', dtype=x_dtype, requires_grad=True) 7819 x = cpu_x.detach().clone().to('mps').requires_grad_() 7820 7821 cpu_y = torch.randn(y_shape, device='cpu', dtype=x_dtype, requires_grad=True) 7822 y = cpu_y.detach().clone().to('mps').requires_grad_() 7823 7824 cpu_out = torch.where(cpu_cond, cpu_x, cpu_y) 7825 out = torch.where(cond, x, y) 7826 7827 cpu_grad = torch.randn(cpu_out.shape) 7828 grad = cpu_grad.to('mps') 7829 7830 cpu_out.backward(gradient=cpu_grad) 7831 out.backward(gradient=grad) 7832 7833 self.assertEqual(out, cpu_out) 7834 self.assertEqual(x.grad, cpu_x.grad) 7835 
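# For reference, a sketch of the contract checked just below (not part of
# the original assertions): torch.where routes the incoming gradient through
# whichever branch was selected, i.e. grad_x = grad_out * cond and
# grad_y = grad_out * ~cond, reduced over any broadcast dimensions. A
# minimal CPU illustration:
#   >>> c = torch.tensor([True, False])
#   >>> a = torch.zeros(2, requires_grad=True)
#   >>> b = torch.zeros(2, requires_grad=True)
#   >>> torch.where(c, a, b).sum().backward()
#   >>> a.grad, b.grad
#   (tensor([1., 0.]), tensor([0., 1.]))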
            self.assertEqual(y.grad, cpu_y.grad)

        for shape in [(0, 3), [], (2, 3), (9,)]:
            helper(shape, shape, shape)

        helper((2, 3, 1), (2, 3, 4), (2, 1, 4))
        helper((2, 1, 1), (2, 3, 4), (1, 3, 4))
        helper((1, 1, 1), (1, 1, 4), (2, 3, 1))
        helper([], (1, 1, 4), (2, 3, 1))
        helper([], (2, 3, 4), [])
        helper((5, 2, 3), (2, 3), (2, 3))
        helper((2, 3), (5, 2, 3), (2, 3))
        helper((2, 3), (2, 3), (5, 2, 3))
        helper((2, 3), (5, 2, 3), (6, 5, 2, 3))
        # Test that the output is correctly resized
        # TODO: Remove me when out OpInfo testing is enabled on MPS
        output = torch.tensor(0.0, device="mps")
        cond = torch.randint(2, (3, 3), dtype=torch.bool, device="mps")
        inp = torch.rand(3, 3, device="mps")
        other = torch.rand(3, 3, device="mps")
        out = torch.where(cond, inp, other, out=output)
        self.assertEqual(id(out), id(output))
        self.assertEqual(out.shape, (3, 3))

    # Test normal
    def test_normal(self):
        def helper(shape, mean=0.0, std=1.0):
            # smoke-test the shape overload; the result is overwritten below
            mps_out = torch.normal(mean, std, shape, device='mps')

            mean_array = np.ones(shape)
            mean_array *= mean
            cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=torch.float, requires_grad=False)
            mean_tensor = cpu_mean_tensor.detach().clone().to('mps')

            std_array = np.ones(shape)
            std_array *= std
            cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False)
            std_tensor = cpu_std_tensor.detach().clone().to('mps')

            # test out
            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean_tensor, std, out=mps_out)

            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean, std_tensor, out=mps_out)

            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean_tensor, std_tensor, out=mps_out)

            # test without out
            mps_out = torch.normal(mean_tensor, std)
            self.assertEqual(mps_out.size(), mean_tensor.size())

            mps_out = torch.normal(mean, std_tensor)
            self.assertEqual(mps_out.size(), std_tensor.size())

            inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size())
            mps_out = torch.normal(mean_tensor, std_tensor)
            self.assertEqual(mps_out.size(), inferred_shape)

        helper((2, 3, 4, 5, 6))
        helper((100, 100), 2.5, 1.2)

    def test_bernoulli(self):
        shape = (10, 10)
        all_ones = torch.ones(shape, device='mps')
        all_zeros = torch.zeros(shape, device='mps')

        prob_tensor = all_ones * 0.5
        # probability of drawing "1" is 0.5
        mps_out = torch.bernoulli(prob_tensor)
        # We can't reliably check the mean and std,
        # just make sure we don't return constant values
        self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
        self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.)
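        # extra sanity check (not part of the original statistics checks):
        # bernoulli may only ever produce the values 0 and 1, whatever p is
        self.assertTrue(((mps_out == 0) | (mps_out == 1)).all())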

        # probability of drawing "1" is 0
        mps_out = torch.bernoulli(all_zeros)
        self.assertEqual(mps_out, all_zeros)

        # probability of drawing "1" is 1
        mps_out = torch.bernoulli(all_ones)
        self.assertEqual(mps_out, all_ones)

        # Check it works for different dtypes
        for dtype in [torch.float16, torch.int8, torch.int16, torch.int32, torch.int64]:
            mps_out = torch.zeros(shape, device='mps', dtype=dtype).bernoulli(0.5)
            # Check that the output is not all zeros or ones
            if product_version > 13.0:
                uniq = mps_out.unique()
                self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype))
            else:
                self.assertEqual(mps_out.min().item(), 0.)
                self.assertEqual(mps_out.max().item(), 1.)

    def test_mps_generator(self):
        # explicit manual seeding by creating an MPS Generator
        g_mps = torch.Generator(device='mps')
        g_mps.manual_seed(999)
        mps_x = torch.randn(5, device='mps', generator=g_mps)
        g_mps.manual_seed(999)
        # generate random numbers with offset `0`
        mps_y = torch.randn(5, device='mps', generator=g_mps)
        # seed values were the same, so the random tensor contents should match
        self.assertEqual(mps_x, mps_y)
        # save the generator's state (offset = 1) to restore it later
        g_state = g_mps.get_state()

        # generate random numbers with offset `1`
        mps_x = torch.randn(5, device='mps', generator=g_mps)
        # in this case, the random results must differ from the last generated random results
        self.assertNotEqual(mps_x, mps_y)

        # mps_x was produced from the state saved in g_state, so use it as the new reference mps_y
        mps_y = mps_x

        # restore the previously saved state, and the results should match again
        g_mps.set_state(g_state)
        mps_x = torch.randn(5, device='mps', generator=g_mps)
        self.assertEqual(mps_x, mps_y)

    @serialTest()
    def test_default_mps_generator(self):
        # manual seeding on the "default" MPS generator using
        # the global torch.manual_seed()
        torch.manual_seed(230)
        mps_x = torch.randn(5, device='mps')
        # manual seeding using torch.mps.manual_seed(),
        # which should set the "default" MPS generator
        # just like the global torch.manual_seed()
        torch.mps.manual_seed(230)
        # generate random numbers with offset `0`
        mps_y = torch.randn(5, device='mps')
        # seed values were the same, so the random tensor contents should match
        self.assertEqual(mps_x, mps_y)

        # save the default generator's state (offset = 1) to restore it later
        g_state = torch.mps.get_rng_state()

        # generate random numbers with offset `1`
        mps_x = torch.randn(5, device='mps')
        # in this case, the random results must differ from the last generated random results
        self.assertNotEqual(mps_x, mps_y)
        # since we called randn twice after seeding, the offset should be 2
        self.assertEqual(torch.mps._get_default_mps_generator().get_offset(), 2)

        # mps_x was produced from the state saved in g_state, so use it as the new reference mps_y
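        # (restoring a saved state rewinds the generator's offset, so the next
        # randn call below must reproduce exactly this tensor)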
        mps_y = mps_x

        # restore the previously saved state to the "default" MPS generator, and the results should match again
        torch.mps.set_rng_state(g_state)
        mps_x = torch.randn(5, device='mps')
        self.assertEqual(mps_x, mps_y)

    def test_device_synchronize(self):
        # just running some ops, each followed by a synchronize to wait for
        # the MPS stream to finish running each of them
        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
            .to(device='mps', dtype=torch.float)

        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
        torch.mps.synchronize()
        x = net1(x)
        torch.mps.synchronize()
        x.backward(torch.randn_like(x))
        torch.mps.synchronize()

    @serialTest()
    def test_mps_allocator_module(self):
        # first garbage collect and empty the cached blocks
        gc.collect()
        torch.mps.empty_cache()
        # measure memory allocations from MPSAllocator
        current_alloc_before = torch.mps.current_allocated_memory()
        # after garbage collection and emptying the cache the
        # current_allocated_memory must be zero
        self.assertEqual(current_alloc_before, 0)
        # measure total memory allocations from the Metal driver
        driver_alloc_before = torch.mps.driver_allocated_memory()
        # allocate a new 32 MB tensor (8M float32 elements) to force allocation of a new Metal heap
        x = torch.ones(1024 * 1024 * 8, device="mps")
        # get memory allocations after allocating tensor x
        current_alloc_after = torch.mps.current_allocated_memory()
        driver_alloc_after = torch.mps.driver_allocated_memory()
        # current and driver memory allocations must have
        # grown at this point
        self.assertGreater(current_alloc_after, current_alloc_before)
        self.assertGreater(driver_alloc_after, driver_alloc_before)

    def test_mps_allocator_stats(self):
        max_memory = torch.mps.recommended_max_memory()
        print(f"Recommended Max Memory: {max_memory / 1024 ** 3} GB")
        self.assertGreater(max_memory, 0)

    # to verify this test, run the Xcode Instruments "Metal System Trace" or "Logging" tool,
    # press record, then run this Python test, and press stop.
    # Then expand os_signposts->PyTorchMPS and check whether events or intervals are logged,
    # like this example:
    # "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)"
    def test_mps_profiler_module(self):
        with torch.mps.profiler.profile(mode="event", wait_until_completed=False) as p:
            # just running some ops to capture the OS Signposts traces for profiling
            net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
                .to(device='mps', dtype=torch.float)
            x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
            x = net1(x)

        torch.mps.profiler.start(mode="interval", wait_until_completed=True)
        # just running some ops to capture the OS Signposts traces for profiling
        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
        x = net1(x)
        torch.mps.profiler.stop()

    def test_mps_event_module(self):
        startEvent = torch.mps.Event(enable_timing=True)
        startEvent.record()
        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
            .to(device='mps', dtype=torch.float)
        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
        x = net1(x)
        endEvent = torch.mps.Event(enable_timing=True)
        endEvent.record()
        elapsedTime = startEvent.elapsed_time(endEvent)
        self.assertGreater(elapsedTime, 0.0)

    def test_jit_save_load(self):
        m = torch.nn.Module()
        m.x = torch.rand(3, 3, device='mps')
        buffer = io.BytesIO()
        torch.jit.save(torch.jit.script(m), buffer)
        buffer.seek(0)
        n = torch.jit.load(buffer)
        self.assertEqual(n.x, m.x)

    # Test torch.randint and Tensor.random_
    def test_random(self):
        def helper(shape, low, high, dtype=torch.int32):
            mps_out = torch.randint(low, high, shape, dtype=dtype, device='mps')

            # We can't reliably check the mean and std,
            # just make sure we don't return constant values
            self.assertNotEqual(mps_out.float().mean().item(), 0.)
            self.assertNotEqual(mps_out.float().std().item(), 0.)
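            # (randint samples uniformly in [low, high); for a 100x100 draw a
            # constant output is effectively impossible, so these cheap checks
            # are safe without making the test flaky)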
8079 8080 helper([100, 100], 0, 10) 8081 helper([100, 100], 23, 89) 8082 helper([100, 100], 23, 89, dtype=torch.float32) 8083 helper([100, 100], 23, 89, dtype=torch.int64) 8084 helper([100, 100], 0, 2, dtype=torch.bool) 8085 8086 # Test random_ 8087 for dtype in [torch.bool, torch.int8, torch.uint8, torch.int32, torch.float16, torch.float32]: 8088 x = torch.empty(10, 10, dtype=dtype, device='mps') 8089 x.random_() 8090 self.assertNotEqual(x.max().item(), 0) 8091 8092 # Test exponential 8093 def test_exponential(self): 8094 def helper(shape, lamda, dtype=torch.float32): 8095 8096 mps_out = torch.zeros(shape, device='mps', dtype=dtype) 8097 mps_out.exponential_(lamda) 8098 8099 print(mps_out.to('cpu').float().mean(), 1 / lamda) 8100 print(mps_out.to('cpu').float().std() ** 2, 1 / (lamda**2)) 8101 8102 for dtype in [torch.float32, torch.float16]: 8103 helper([100, 100], 2, dtype) 8104 helper([100, 100], 1, dtype) 8105 helper([100, 100], 3, dtype) 8106 helper([100, 100], 0.5, dtype) 8107 8108 def test_exponential_1(self): 8109 rate = torch.randn(5, 5).abs().requires_grad_() 8110 rate_1d = torch.randn(1).abs().requires_grad_() 8111 self.assertEqual(Exponential(rate).sample().size(), (5, 5)) 8112 self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5)) 8113 self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1)) 8114 self.assertEqual(Exponential(rate_1d).sample().size(), (1,)) 8115 self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,)) 8116 self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,)) 8117 8118 # Test add 8119 def test_add_sub(self): 8120 def helper(shape, alpha, op_name, inplace): 8121 if op_name == "add": 8122 op = torch.Tensor.add_ if inplace else torch.add 8123 elif op_name == "sub": 8124 op = torch.Tensor.sub_ if inplace else torch.sub 8125 8126 for dtype in [torch.float16, torch.float32]: 8127 cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) 8128 mps_x = cpu_x.detach().clone().to('mps') 8129 8130 cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) 8131 mps_y = cpu_y.detach().clone().to('mps') 8132 8133 cpu_out = op(cpu_x, cpu_y, alpha=alpha) 8134 mps_out = op(mps_x, mps_y, alpha=alpha) 8135 # fp16 isn't accurate when alpha is passed 8136 # TODO: remove or fix 'tol' when we fix problems with fp16 8137 tol = 2e-3 if dtype is torch.float16 else None 8138 self.assertEqual(mps_out, cpu_out, rtol=tol, atol=tol) 8139 if not (cpu_y.shape != () and inplace): # in-place output cannot be broadcasted. 
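                    # also exercise a scalar tensor on either side of the op; this is
                    # skipped for in-place ops with a non-scalar y, since the broadcast
                    # result could not be written back into the scalar's storage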
8140 # create a scalar tensor 8141 cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False) 8142 mps_s = cpu_s.detach().clone().to('mps') 8143 # primary tensor is scalar 8144 self.assertEqual(op(cpu_s, cpu_y), op(mps_s, mps_y)) 8145 # create a scalar tensor 8146 cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False) 8147 mps_s = cpu_s.detach().clone().to('mps') 8148 # secondary tensor is scalar 8149 self.assertEqual(op(cpu_x, cpu_s), op(mps_x, mps_s), rtol=tol, atol=tol) 8150 8151 8152 for op_name, inplace in product(["add", "sub"], [True, False]): 8153 helper((), 0.0, op_name, inplace) 8154 helper((2, 8, 4, 5), 0.0, op_name, inplace) 8155 helper((2, 8, 4, 5), 0.1, op_name, inplace) 8156 helper((2, 8, 4, 5), 1.0, op_name, inplace) 8157 helper((2, 8, 3, 5), 0.1, op_name, inplace) 8158 helper((2, 8, 3, 5), 0.2, op_name, inplace) 8159 8160 # Test add 8161 def test_add_scalars(self): 8162 def helper(alpha): 8163 for dtype in [torch.float16, torch.float32]: 8164 cpu_x = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False) 8165 x = cpu_x.detach().clone().to('mps') 8166 8167 cpu_y = torch.tensor(3.4, device='cpu', dtype=dtype, requires_grad=False) 8168 y = cpu_y.detach().clone().to('mps') 8169 8170 cpu_out = torch.add(cpu_x, cpu_y, alpha=alpha) 8171 out = torch.add(x, y, alpha=alpha) 8172 # fp16 isn't accurate when alpha is passed 8173 tol = 1e-3 if dtype is torch.float16 else None 8174 self.assertEqual(out, cpu_out, rtol=tol, atol=tol) 8175 8176 helper(1.0) 8177 helper(0.0) 8178 helper(0.1) 8179 helper(0.2) 8180 8181 # Test int32 tensor + int64 scalar add 8182 # see https://github.com/pytorch/pytorch/issues/79835#issuecomment-1164984534 8183 x = torch.ones(4, dtype=torch.int32, device='mps') 8184 self.assertEqual(x + 1, torch.full((4,), 2, dtype=torch.int32, device='mps')) 8185 self.assertTrue(torch.equal(x + 1.5, torch.full((4,), 2.5, device='mps'))) 8186 8187 def test_types_binary_op(self): 8188 # Float * Bool 8189 cpu_x = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([True, False, True, False, True], device="cpu") 8190 mps_x = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([True, False, True, False, True], device="mps") 8191 self.assertEqual(cpu_x, mps_x) 8192 # Float * Int64 8193 cpu_y = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([1, 0, 1, 0, 1], device="cpu") 8194 mps_y = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([1, 0, 1, 0, 1], device="mps") 8195 self.assertEqual(cpu_y, mps_y) 8196 8197 def test_unary_ops(self): 8198 def helper(shape, op): 8199 for dtypef in [torch.float32]: 8200 cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False) 8201 mps_x = cpu_x.detach().clone().to('mps') 8202 self.assertEqual(op(cpu_x), op(mps_x)) 8203 8204 for dtypei in [torch.int32, torch.int16]: 8205 cpu_x = torch.randint(0, 1000, shape, device='cpu', dtype=dtypei, requires_grad=False) 8206 mps_x = cpu_x.to('mps') 8207 self.assertEqual(op(cpu_x), op(mps_x), rtol=1e-4, atol=1e-4) 8208 # test slice 8209 for dtypef in [torch.float32]: 8210 cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False) 8211 mps_x = cpu_x.detach().clone().to('mps') 8212 cpu_slice = cpu_x[:, ::2, :, :] 8213 mps_slice = mps_x[:, ::2, :, :] 8214 self.assertEqual(op(cpu_slice), op(mps_slice)) 8215 # test view 8216 for dtypef in [torch.float32]: 8217 cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False) 8218 mps_x = cpu_x.detach().clone().to('mps') 
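                # (collapsing the last two dims of a contiguous tensor yields a true
                # view, so this exercises shape changes without a memory copy)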
8219 # create view of tensor by reducing the 3rd and 4th dimension 8220 combined_dim = shape[-1] * shape[-2] 8221 reshaped_dims = list(shape[:-2]) + [combined_dim] 8222 cpu_view = cpu_x.view(*reshaped_dims) 8223 mps_view = mps_x.view(*reshaped_dims) 8224 self.assertEqual(op(cpu_view), op(mps_view)) 8225 8226 helper((2, 8, 4, 5), torch.exp) 8227 helper((2, 8, 3, 5), torch.exp2) 8228 helper((2, 8, 3, 5), torch.expm1) 8229 helper((2, 8, 3, 5), torch.log) 8230 helper((2, 8, 3, 5), torch.cos) 8231 helper((2, 8, 3, 5), torch.erfinv) 8232 8233 8234 def test_non_dense_in_storage_unary_ops(self): 8235 def helper(op): 8236 for dtypef in [torch.float32]: 8237 cpu_x = torch.randn(100, device='cpu', dtype=dtypef, requires_grad=False) 8238 mps_x = cpu_x.detach().clone().to('mps') 8239 self.assertEqual(op(cpu_x[::2]), op(mps_x[::2])) 8240 8241 for dtypei in [torch.int32, torch.int16, torch.int8]: 8242 cpu_x = torch.randint(127, device='cpu', size=(100,), dtype=dtypei, requires_grad=False) 8243 mps_x = cpu_x.to('mps') 8244 self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]), rtol=1e-4, atol=1e-4) 8245 8246 helper(torch.exp) 8247 helper(torch.exp2) 8248 helper(torch.expm1) 8249 helper(torch.log) 8250 helper(torch.cos) 8251 8252 def test_unary_ops_storage_offset_strided(self): 8253 def helper(shape, op, inplace, dtype=torch.float32): 8254 # test in-place with storage_offset 8255 cpu_x = torch.randn(shape, device='cpu', dtype=dtype) 8256 mps_x = cpu_x.detach().clone().to('mps') 8257 y = op(mps_x[1]) 8258 cpu_y = op(cpu_x[1]) 8259 self.assertEqual(y, cpu_y) 8260 8261 8262 # See https://github.com/pytorch/pytorch/issues/100764 8263 if not inplace: 8264 cpu_x = torch.randn(shape, device='cpu', dtype=dtype) 8265 mps_x = cpu_x.detach().clone().to('mps') 8266 cpu_y = torch.empty(shape, device='cpu', dtype=dtype).t() 8267 mps_y = cpu_y.detach().clone().to('mps') 8268 op(cpu_x, out=cpu_y) 8269 op(mps_x, out=mps_y) 8270 self.assertEqual(mps_y, cpu_y) 8271 8272 8273 helper((5, 5), torch.exp, False) 8274 helper((5, 5), torch.cos, False) 8275 helper((5, 5), torch.neg, False) 8276 helper((5, 5), torch.tanh, False) 8277 helper((5, 5), torch.tanh_, True) 8278 8279 def test_atan2(self): 8280 def helper(shape): 8281 input_cpu = torch.randn(shape) 8282 input_mps = input_cpu.detach().clone().to("mps") 8283 8284 other_cpu = torch.randn(shape) 8285 other_mps = other_cpu.detach().clone().to("mps") 8286 8287 atan2_cpu = torch.atan2(input_cpu, other_cpu) 8288 atan2_mps = torch.atan2(input_mps, other_mps) 8289 8290 self.assertEqual(atan2_cpu, atan2_mps.to("cpu")) 8291 8292 helper(4) 8293 helper(10000) 8294 helper((10000, 40)) 8295 8296 def test_multinomial(self): 8297 # Test with num_dist = 1 8298 def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True): 8299 cpu_prob_tensor = torch.tensor(probs, device='cpu', dtype=torch.float, requires_grad=False) 8300 prob_tensor = cpu_prob_tensor.detach().clone().to('mps') 8301 8302 mps_out = torch.multinomial(prob_tensor, num_samples, replacement=replacement) 8303 if (not replacement): 8304 print(mps_out.to('cpu')) 8305 else: 8306 # Compare "real" with theoretical values 8307 print(mps_out.to('cpu').float().mean(), compare_mean) 8308 print(mps_out.to('cpu').float().std() ** 2, compare_var) 8309 8310 # TODO: Add tests for data types 8311 helper(np.array([[0., 0., 0., 0.5, 0.5]]), (3 + 4) / 2, (12.5 - 3.5 ** 2), 100000) 8312 helper(np.array([[.2, .2, .2, .2, .2]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000) 8313 helper(np.array([[1, 1, 1, 1, 1]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 
2 * 2), 10000) 8314 helper(np.array([1, 1, 1, 1, 1]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000) 8315 helper(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False) 8316 8317 def test_cumsum_dim_check(self): 8318 x = torch.rand((3, 3), device="mps") 8319 self.assertEqual(x.cumsum(1), x.cumsum(-1)) 8320 self.assertEqual(x.cumsum(0), x.cumsum(-2)) 8321 self.assertRaises(IndexError, lambda: x.cumsum(2)) 8322 self.assertRaises(IndexError, lambda: x.cumsum(-3)) 8323 8324 def test_cumprod_dim_check(self): 8325 x = torch.rand((3, 3), device="mps") 8326 self.assertEqual(x.cumprod(1), x.cumprod(-1)) 8327 self.assertEqual(x.cumprod(0), x.cumprod(-2)) 8328 self.assertRaises(IndexError, lambda: x.cumprod(2)) 8329 self.assertRaises(IndexError, lambda: x.cumprod(-3)) 8330 8331class TestLogical(TestCaseMPS): 8332 def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False): 8333 return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad) 8334 8335 def test_logical_not(self): 8336 def helper(x): 8337 cpu_x = x 8338 x = cpu_x.detach().clone().to('mps') 8339 8340 result = torch.logical_not(x) 8341 result_cpu = torch.logical_not(cpu_x) 8342 8343 self.assertEqual(result, result_cpu) 8344 8345 helper(self._wrap_tensor([1, 1, 0, 0])) 8346 helper(self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True)) 8347 helper(self._wrap_tensor([True, True, False, False])) 8348 helper(self._wrap_tensor(1)) 8349 helper(self._wrap_tensor(0)) 8350 helper(self._wrap_tensor(True)) 8351 helper(self._wrap_tensor(False)) 8352 8353 def test_logical_and(self): 8354 def helper(x, other): 8355 cpu_x = x 8356 x = cpu_x.detach().clone().to('mps') 8357 8358 cpu_other = other 8359 other = cpu_other.detach().clone().to('mps') 8360 8361 result = torch.logical_and(x, other) 8362 result_cpu = torch.logical_and(cpu_x, cpu_other) 8363 self.assertEqual(result, result_cpu) 8364 8365 helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1])) 8366 helper( 8367 self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True), 8368 self._wrap_tensor([1, 0, 0, 1], dtype=torch.float) 8369 ) 8370 helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True])) 8371 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1)) 8372 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0)) 8373 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True)) 8374 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False)) 8375 8376 def test_logical_or(self): 8377 def helper(x, other): 8378 cpu_x = x 8379 x = cpu_x.detach().clone().to('mps') 8380 8381 cpu_other = other 8382 other = cpu_other.detach().clone().to('mps') 8383 8384 result = torch.logical_or(x, other) 8385 result_cpu = torch.logical_or(cpu_x, cpu_other) 8386 8387 self.assertEqual(result, result_cpu) 8388 8389 helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1])) 8390 helper( 8391 self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True), 8392 self._wrap_tensor([1, 0, 0, 1], dtype=torch.float) 8393 ) 8394 helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True])) 8395 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1)) 8396 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0)) 8397 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True)) 8398 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False)) 8399 8400 def test_logical_xor(self): 8401 def helper(x, other): 8402 cpu_x = x 8403 
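            # logical_xor returns a bool tensor for any input dtype, so the MPS
            # and CPU results can be compared exactly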
x = cpu_x.detach().clone().to('mps') 8404 8405 cpu_other = other 8406 other = cpu_other.detach().clone().to('mps') 8407 8408 result = torch.logical_xor(x, other) 8409 result_cpu = torch.logical_xor(cpu_x, cpu_other) 8410 8411 self.assertEqual(result, result_cpu) 8412 8413 helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1])) 8414 helper( 8415 self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True), 8416 self._wrap_tensor([1, 0, 0, 1], dtype=torch.float) 8417 ) 8418 helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True])) 8419 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1)) 8420 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0)) 8421 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True)) 8422 helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False)) 8423 8424 def test_min_max(self): 8425 def helper(dtype): 8426 for _ in range(10): 8427 if dtype == torch.float32 or dtype == torch.float16: 8428 x = torch.randn((30, 15), device='mps', dtype=dtype) 8429 else: 8430 x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype) 8431 x_cpu = x.to("cpu") 8432 8433 y = x.max() 8434 y_cpu = x_cpu.max() 8435 self.assertEqual(y, y_cpu) 8436 8437 z = x.min() 8438 z_cpu = x_cpu.min() 8439 self.assertEqual(z, z_cpu) 8440 8441 [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]] 8442 8443 def test_min_max_nan_propagation(self): 8444 def helper(dtype): 8445 cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu") 8446 mps_x = cpu_x.detach().clone().to('mps') 8447 8448 cpu_max = torch.max(cpu_x) 8449 mps_max = torch.max(mps_x).to('cpu') 8450 8451 cpu_amax = torch.amax(cpu_x) 8452 mps_amax = torch.amax(mps_x).to('cpu') 8453 8454 cpu_min = torch.min(cpu_x) 8455 mps_min = torch.min(mps_x).to('cpu') 8456 8457 cpu_amin = torch.amin(cpu_x) 8458 mps_amin = torch.amin(mps_x).to('cpu') 8459 8460 self.assertEqual(cpu_max, mps_max) 8461 self.assertEqual(cpu_amax, mps_amax) 8462 self.assertEqual(cpu_min, mps_min) 8463 self.assertEqual(cpu_amin, mps_amin) 8464 [helper(dtype) for dtype in [torch.float32, torch.float16, torch.bfloat16]] 8465 8466 def test_isin(self): 8467 def helper(dtype): 8468 shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]), 8469 ([5], [10]), ([0], [5]), ([5], [0])] 8470 for shape_tuple in shapes: 8471 for inverted in [True, False]: 8472 if dtype.is_floating_point: 8473 # Half is not supported for CPU isin. 
Compute the reference in FP32.
                        A = torch.randn(size=shape_tuple[0], device='cpu', dtype=torch.float32)
                        B = torch.randn(size=shape_tuple[1], device='cpu', dtype=torch.float32)
                    else:
                        A = torch.randint(0, 100, size=shape_tuple[0], device='cpu', dtype=dtype)
                        B = torch.randint(0, 100, size=shape_tuple[1], device='cpu', dtype=dtype)

                    A_mps = A.clone().detach().to('mps')
                    B_mps = B.clone().detach().to('mps')

                    # isin returns a bool tensor for any input dtype, so no cast
                    # of the reference is needed
                    cpu_ref = torch.isin(A, B, invert=inverted)

                    mps_out = torch.isin(A_mps, B_mps, invert=inverted)
                    self.assertEqual(mps_out, cpu_ref)

        dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8]
        if product_version < 14.0:
            # Int types expected to fail on MacOS < 14.0
            dtypes = [torch.float32, torch.float16, torch.bfloat16]

        [helper(dtype) for dtype in dtypes]

    def test_isin_asserts(self):
        A = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
        B = torch.randn(size=[1, 4], device='mps', dtype=torch.float16)
        with self.assertRaisesRegex(RuntimeError, 'Expected elements.dtype()*'):
            out = torch.isin(A, B)

        C = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
        D = torch.randn(size=[1, 4], device='cpu', dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, 'Expected elements.is_mps()*'):
            out = torch.isin(C, D)


class TestSmoothL1Loss(TestCaseMPS):

    def _smooth_l1_loss_helper(self, reduction="mean", requires_grad=False):
        # CPU
        input_cpu = torch.randn(4, 7, requires_grad=requires_grad)
        target_cpu = torch.randn(4, 7)

        # MPS
        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
        target_mps = target_cpu.detach().clone().to('mps')

        smooth_l1_loss_cpu = F.smooth_l1_loss(input_cpu, target_cpu, beta=1.0, reduction=reduction)
        smooth_l1_loss_mps = F.smooth_l1_loss(input_mps, target_mps, beta=1.0, reduction=reduction)

        self.assertEqual(smooth_l1_loss_cpu, smooth_l1_loss_mps)

        if requires_grad:
            smooth_l1_loss_cpu.backward()
            smooth_l1_loss_mps.backward()
            self.assertEqual(input_cpu.grad, input_mps.grad.to("cpu"))

        return smooth_l1_loss_cpu, smooth_l1_loss_mps

    def test_smooth_l1_loss_reduction_none(self):
        self._smooth_l1_loss_helper(reduction="none")

    def test_smooth_l1_loss_reduction_mean(self):
        self._smooth_l1_loss_helper(reduction="mean")

    def test_smooth_l1_loss_reduction_sum(self):
        self._smooth_l1_loss_helper(reduction="sum")

    def test_smooth_l1_loss_reduction_mean_backward(self):
        self._smooth_l1_loss_helper(reduction="mean", requires_grad=True)

    def test_smooth_l1_loss_reduction_sum_backward(self):
        self._smooth_l1_loss_helper(reduction="sum", requires_grad=True)


class TestNLLLoss(TestCaseMPS):
    def test_nll_loss_mismatched_batch(self, device='mps'):
        x = torch.randn((10, 3), requires_grad=True, device=device)
        # t should have size (10,)
        t = torch.zeros((3,), dtype=torch.int64, device=device)
        with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
            F.nll_loss(x, t)

    def test_nll_loss_out_of_bounds_ignore_index(self):

        def test_nll_loss_out_of_bounds_ignore_index_helper(device):
            output = []
            x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
8560 0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device) 8561 t1 = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device) 8562 t2 = torch.tensor([0, 1, 1, 0, -100, 2], dtype=torch.int64, device=device) 8563 for reduction in ['mean', 'none']: 8564 # out of bound ignore_index 8565 output.append(F.nll_loss(x, t1, ignore_index=255, reduction=reduction)) 8566 # default ignore_index 8567 output.append(F.nll_loss(x, t2, reduction=reduction)) 8568 return output 8569 8570 output_cpu = test_nll_loss_out_of_bounds_ignore_index_helper(device='cpu') 8571 output_mps = test_nll_loss_out_of_bounds_ignore_index_helper(device='mps') 8572 8573 for cpu, mps in zip(output_cpu, output_mps): 8574 self.assertEqual(cpu, mps) 8575 8576 def test_nll_loss_invalid_target_dim(self): 8577 8578 def _test_nll_loss_invalid_target_dim(device): 8579 output = [] 8580 x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [ 8581 0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device) 8582 t = torch.zeros((6, 2), dtype=torch.int64, device=device) 8583 with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"): 8584 F.nll_loss(x, t) 8585 8586 _test_nll_loss_invalid_target_dim(device='cpu') 8587 _test_nll_loss_invalid_target_dim(device='mps') 8588 8589 def test_nll_loss_invalid_weights(self): 8590 8591 def _test_nll_loss_invalid_weights(device): 8592 x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [ 8593 0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device) 8594 t = torch.tensor([0, 1, 2, 1, 1, 2], dtype=torch.int64, device=device) 8595 invalid_weights = [ 8596 torch.zeros(4, device=device), 8597 torch.zeros((1, 3), device=device), 8598 ] 8599 msg = "weight tensor should be defined either for all 3 classes or no classes" 8600 for weight in invalid_weights: 8601 with self.assertRaisesRegex(RuntimeError, msg): 8602 F.nll_loss(x, t, weight=weight) 8603 8604 _test_nll_loss_invalid_weights(device='cpu') 8605 _test_nll_loss_invalid_weights(device='mps') 8606 8607 def _nll_loss_helper(self, input_size, reduction, expected): 8608 8609 # CPU 8610 input = torch.rand(input_size, requires_grad=True, device='cpu') 8611 num_channels = input_size[1] 8612 target_size = (input_size[0], ) + tuple(input_size[2:]) 8613 target = torch.randint(num_channels, target_size, device='cpu') 8614 weights = torch.randn(num_channels) 8615 8616 # MPS 8617 input_mps = input.detach().clone().to('mps').requires_grad_() 8618 target_mps = target.detach().clone().to('mps') 8619 weights_mps = weights.to("mps") 8620 8621 output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction) 8622 output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction) 8623 self.assertEqual(output_cpu, output_mps.to('cpu')) 8624 8625 output_cpu.sum().backward() 8626 output_mps.sum().backward() 8627 self.assertEqual(input.grad, input_mps.grad.to('cpu')) 8628 8629 def _nll_loss_1d_helper(self, input_size, reduction): 8630 8631 # CPU 8632 input = torch.rand(input_size, requires_grad=True, device='cpu') 8633 num_channels = input_size[0] 8634 target = torch.randint(num_channels, [], device='cpu') 8635 8636 # MPS 8637 input_mps = input.detach().clone().to('mps').requires_grad_() 8638 target_mps = target.detach().clone().to('mps') 8639 8640 output_cpu = F.nll_loss(input, target, reduction=reduction) 8641 output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction) 8642 self.assertEqual(output_cpu, output_mps.to('cpu')) 8643 8644 output_cpu.sum().backward() 
        output_mps.sum().backward()
        self.assertEqual(input.grad, input_mps.grad.to('cpu'))

    def test_nll_loss_1d(self, device='cpu'):
        self._nll_loss_1d_helper([10], "none")
        self._nll_loss_1d_helper([10], "mean")
        self._nll_loss_1d_helper([10], "sum")

    def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
        self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
        self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))
        self._nll_loss_helper([2, 3, 1, 7], "none", torch.empty([2, 1, 7], device=device))
        self._nll_loss_helper([2, 3, 5, 1], "none", torch.empty([2, 5, 1], device=device))
        self._nll_loss_helper([2, 3, 5, 7, 1], "none", torch.empty([2, 5, 7, 1], device=device))

    def test_nll_loss_empty_tensor_reduction_mean(self, device='cpu'):
        nan = torch.tensor(float('nan'), device=device)
        self._nll_loss_helper([1, 3], "mean", nan)
        self._nll_loss_helper([1, 3, 5, 7], "mean", nan)
        self._nll_loss_helper([2, 3, 1, 7], "mean", nan)
        self._nll_loss_helper([2, 3, 5, 1], "mean", nan)
        self._nll_loss_helper([2, 3, 5, 7, 1], "mean", nan)

    def test_nll_loss_empty_tensor_reduction_sum(self, device='cpu'):
        zero = torch.tensor(0, device=device)
        self._nll_loss_helper([1, 3], "sum", zero)
        self._nll_loss_helper([1, 3, 5, 7], "sum", zero)
        self._nll_loss_helper([2, 3, 1, 7], "sum", zero)
        self._nll_loss_helper([2, 3, 5, 1], "sum", zero)
        self._nll_loss_helper([2, 3, 5, 7, 1], "sum", zero)

    def test_nll_loss_byte_target_matches_long(self, device='cpu'):
        N, C = 10, 4
        input = torch.randn(N, C, device=device, requires_grad=True)
        target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)

        def compute_result_and_gradient(reduction, target_dtype):
            result, grad = {}, {}
            for dev in ['cpu', 'mps']:
                input_dev = input.to(dev)
                input_ = input_dev.detach()
                input_.requires_grad_()

                target_dev = target.to(dev)

                prob = F.log_softmax(input_, dim=-1)
                loss = nn.NLLLoss(reduction=reduction)
                result[dev] = loss(prob, target_dev.to(target_dtype))
                result[dev].sum().backward()
                grad[dev] = input_.grad

            return result, grad

        for reduction in ["none", "mean", "sum"]:
            result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
            result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)

            self.assertEqual(result_long['mps'].to('cpu'), result_long['cpu'])
            self.assertEqual(grad_long['mps'].to('cpu'), grad_long['cpu'])
            # byte targets must behave exactly like long targets, on both devices
            # (these results were previously computed but never asserted on)
            self.assertEqual(result_byte['cpu'], result_long['cpu'])
            self.assertEqual(result_byte['mps'].to('cpu'), result_long['mps'].to('cpu'))
            self.assertEqual(grad_byte['mps'].to('cpu'), grad_byte['cpu'])


class TestTopK(TestCase):
    def _test_topk(self, shape, largest):
        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
        x = cpu_x.detach().clone().to('mps')
        if isinstance(shape, tuple):
            for curr_dim, dim_size in enumerate(shape):
                for k in range(1, dim_size + 1):
                    topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest)
                    topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest)
                    self.assertEqual(topk_values, topk_values_cpu)
                    self.assertEqual(topk_indices, topk_indices_cpu)
        else:
            for k in range(1, shape):
                topk_values, topk_indices = torch.topk(x, k, dim=0, largest=largest)
                topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=0, largest=largest)
                self.assertEqual(topk_values, topk_values_cpu)
self.assertEqual(topk_indices, topk_indices_cpu) 8722 8723 def test_topk(self): 8724 largest_vals = [True, False] 8725 shapes = [ 8726 # Zero Element Tensors 8727 0, 8728 (1, 0), 8729 (0, 1), 8730 (1, 0, 1), 8731 # Multiple Element Tensors 8732 1, 8733 2, 8734 (5, 1), 8735 (1, 5), 8736 (5, 9, 7, 4), 8737 ] 8738 8739 for shape in shapes: 8740 for largest_val in largest_vals: 8741 with self.subTest(shape=shape, largest_val=largest_val): 8742 self._test_topk(shape, largest_val) 8743 8744class TestNNMPS(NNTestCase): 8745 8746 def _create_basic_net(self): 8747 class Layer(nn.Module): 8748 def __init__(self) -> None: 8749 super().__init__() 8750 self.layer_dummy_param = Parameter(torch.empty(3, 5)) 8751 self.layer_dummy_buf = Buffer(torch.zeros(1, 3, 3, 7)) 8752 8753 class Net(nn.Module): 8754 def __init__(self) -> None: 8755 super().__init__() 8756 self.l1 = Layer() 8757 self.dummy_param = Parameter(torch.empty(3, 5)) 8758 self.dummy_buf = Buffer(torch.zeros(7, 3, 3, 1)) 8759 8760 l = Layer() 8761 n = Net() 8762 s = nn.Sequential(n, n) 8763 8764 return l, n, s 8765 8766 def test_requires_grad_(self): 8767 m = self._create_basic_net()[-1] 8768 assert len(list(m.buffers())) > 0, 'invalid test' 8769 assert all(not b.requires_grad for b in m.buffers()) > 0, 'invalid test' 8770 assert len(list(m.parameters())) > 0, 'invalid test' 8771 assert all(p.requires_grad for p in m.parameters()) > 0, 'invalid test' 8772 for requires_grad in (False, True): 8773 self.assertIs(m.requires_grad_(requires_grad), m) 8774 for p in m.parameters(): 8775 self.assertEqual(p.requires_grad, requires_grad) 8776 for b in m.buffers(): 8777 self.assertFalse(b.requires_grad) 8778 8779 def test_module_backcompat(self): 8780 from torch.serialization import SourceChangeWarning 8781 path = download_file('https://download.pytorch.org/test_data/linear.pt') 8782 with warnings.catch_warnings(): 8783 warnings.simplefilter('ignore', SourceChangeWarning) 8784 m = torch.load(path) 8785 input = torch.randn(2, 3, dtype=torch.float) 8786 self.assertEqual(m(input).size(), (2, 5)) 8787 8788 def test_conv_backcompat(self): 8789 from torch.serialization import SourceChangeWarning 8790 # This file was generated by running on PyTorch 1.0.1 on Python 2: 8791 # 8792 # import torch 8793 # from torch import nn 8794 # m = nn.Conv2d(1, 1, 1) 8795 # torch.save(m, 'legacy_conv2d.pt') 8796 # 8797 # NB: This Pickle also contains some Unicode data! 
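        # (the encoding='utf-8' argument below is what decodes the Python 2
        # byte strings stored in this legacy pickle)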
8798 path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt') 8799 with warnings.catch_warnings(): 8800 warnings.simplefilter('ignore', SourceChangeWarning) 8801 m = torch.load(path, encoding='utf-8') 8802 input = torch.randn((1, 1, 1, 1), dtype=torch.float) 8803 self.assertEqual(m(input).size(), (1, 1, 1, 1)) 8804 8805 def test_conv_expand(self): 8806 device = 'mps' 8807 input_ = torch.rand(2, 3, 16, 16, device=device) 8808 kernel = torch.rand(1, 1, 3, 11, device=device) 8809 tmp_kernel = kernel.expand(-1, 3, -1, -1) 8810 output = F.conv2d(input_, tmp_kernel, groups=1, padding=0, stride=1) 8811 8812 # The test should not crash 8813 def test_permute(self): 8814 M_cpu = torch.randn(5, 5) 8815 M_mps = M_cpu.to('mps') 8816 8817 output_cpu = M_cpu.permute(1, 0) 8818 output_mps = M_mps.permute(1, 0) 8819 8820 self.assertEqual(output_cpu, output_mps) 8821 self.assertEqual(output_cpu.size(), output_mps.size()) 8822 8823 # Printing of non_contiguous should not crash 8824 def test_print_non_contiguous(self): 8825 print(torch.ones(100, 100, device='mps').nonzero()) 8826 print(torch.ones(100, 100, device='mps').nonzero().contiguous()) 8827 8828 def test_zero_grad(self): 8829 i = torch.randn(2, 5, requires_grad=True) 8830 module = nn.Linear(5, 5) 8831 for p in module.parameters(): 8832 p.requires_grad = False 8833 module.zero_grad() 8834 8835 module.weight.requires_grad = True 8836 module.zero_grad() 8837 self.assertIsNone(module.weight.grad) # uninitialized grad 8838 8839 module(i).sum().backward() 8840 self.assertIsNotNone(module.weight.grad) 8841 self.assertGreater(module.weight.grad.data.abs().sum(), 0) 8842 module.zero_grad() 8843 self.assertIsNone(module.weight.grad) 8844 8845 module.bias.requires_grad = True 8846 module.zero_grad() 8847 self.assertIsNone(module.weight.grad) 8848 self.assertIsNone(module.bias.grad) 8849 module(i).sum().backward() 8850 self.assertIsNotNone(module.weight.grad) 8851 self.assertIsNotNone(module.bias.grad) 8852 self.assertGreater(module.weight.grad.data.abs().sum(), 0) 8853 self.assertGreater(module.bias.grad.data.abs().sum(), 0) 8854 8855 # Force set to zeros. 8856 module.zero_grad(set_to_none=False) 8857 self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_()) 8858 self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_()) 8859 8860 module.zero_grad() 8861 self.assertIsNone(module.weight.grad) 8862 self.assertIsNone(module.bias.grad) 8863 8864 8865 def test_no_grad(self): 8866 for dtype in [torch.bfloat16, torch.float, torch.double]: 8867 module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype) 8868 input = torch.randn(1, 2, 10, 10).to(dtype) 8869 x = input 8870 y = input.clone() 8871 8872 output = module(x) 8873 self.assertTrue(output.requires_grad) 8874 output.backward(torch.ones(1, 5, 10, 10)) 8875 8876 with torch.no_grad(): 8877 output2 = module(y) 8878 self.assertFalse(output2.requires_grad) 8879 self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10))) 8880 8881 def test_invalid_conv1d(self): 8882 for dtype in [torch.bfloat16, torch.float, torch.double]: 8883 module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype) 8884 input = torch.randn(1, 3, 4).to(dtype) 8885 with self.assertRaisesRegex(RuntimeError, 8886 r'Calculated padded input size per channel: \(4\). ' + 8887 r'Kernel size: \(10\). 
Kernel size can\'t be greater than actual input size'): 8888 module(input) 8889 8890 # Negative stride check 8891 module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype) 8892 input = torch.randn(1, 3, 4).to(dtype) 8893 with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): 8894 module(input) 8895 8896 def test_conv2d_discontiguous_weight(self): 8897 # Test for https://github.com/pytorch/pytorch/issues/55781 8898 x = torch.ones(64, 16, 16, 16) 8899 weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2)[:, :, :, ::2] 8900 self.assertFalse(weight.is_contiguous()) 8901 y = torch.nn.functional.conv2d(x, weight, None) 8902 if torch.backends.mkldnn.is_available(): 8903 # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used 8904 with torch.backends.mkldnn.flags(enabled=False): 8905 y_ = torch.nn.functional.conv2d(x, weight, None) 8906 self.assertEqual(y, y_) 8907 self.assertEqual(y.sum(), 4186112.) 8908 8909 def test_invalid_conv2d(self): 8910 for dtype in [torch.bfloat16, torch.float, torch.double]: 8911 module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype) 8912 input = torch.empty(1, 1, 4, 4).to(dtype) 8913 self.assertRaises(RuntimeError, lambda: module(input)) 8914 8915 module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True) 8916 input = torch.randn(1, 3, 1, 1) 8917 with self.assertRaisesRegex(RuntimeError, 8918 r'Calculated padded input size per channel: \(1 x 1\). ' + 8919 r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'): 8920 module(input) 8921 8922 # Negative stride check 8923 module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype) 8924 input = torch.randn(1, 3, 4, 4).to(dtype) 8925 with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): 8926 module(input) 8927 8928 # Zero stride check 8929 module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype) 8930 input = torch.randn(1, 3, 4, 4).to(dtype) 8931 with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'): 8932 module(input) 8933 8934 # Input and weights on different devices 8935 self.assertRaisesRegex(RuntimeError, 8936 'must be on the same device', 8937 lambda: torch.conv2d(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 3, 3, device='mps'))) 8938 self.assertRaisesRegex(RuntimeError, 8939 'Input type \\(MPSFloatType\\) and weight type \\(torch\\.FloatTensor\\) should be the same', 8940 lambda: torch.conv2d(torch.rand(1, 3, 32, 32, device='mps'), torch.rand(1, 3, 3, 3))) 8941 8942 8943 def test_conv2d_valid_padding(self, device='mps'): 8944 # Test F.conv2d padding='valid' is the same as no padding 8945 x = torch.rand(1, 1, 1, 10, device=device).to(torch.float) 8946 y = torch.rand(1, 1, 1, 4, device=device).to(torch.float) 8947 8948 expect = F.conv2d(x, y) 8949 actual = F.conv2d(x, y, padding='valid') 8950 self.assertEqual(expect.to('cpu'), actual.to('cpu')) 8951 8952 def test_conv2d_backward_collision(self): 8953 # Test for https://github.com/pytorch/pytorch/issues/112998 8954 x = torch.rand(1, 1, 10, 10, device="mps", requires_grad=True) 8955 m1 = nn.Conv2d(1, 1, 3, stride=2, padding=1).to("mps") 8956 m2 = nn.Conv2d(1, 1, 4, stride=2, padding=1).to("mps") 8957 y1, y2 = m1(x), m2(x) 8958 self.assertEqual(y1.shape, y2.shape) 8959 y1.sum().backward() 8960 # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion 
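        # (y1 and y2 have identical shapes even though the kernel sizes differ,
        # which is exactly the situation that used to trigger the collision)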
8961 y2.sum().backward() 8962 8963 @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12") 8964 def test_conv3d_backward_collision(self): 8965 # Conv3D is only available from MacOS 13.2 onwards 8966 x = torch.rand(1, 1, 10, 10, 20, device="mps", requires_grad=True) 8967 m1 = nn.Conv3d(1, 1, 3, stride=2, padding=1).to("mps") 8968 m2 = nn.Conv3d(1, 1, 4, stride=2, padding=1).to("mps") 8969 y1, y2 = m1(x), m2(x) 8970 self.assertEqual(y1.shape, y2.shape) 8971 y1.sum().backward() 8972 # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion 8973 y2.sum().backward() 8974 8975 def test_gemm_permute_transpose(self): 8976 batch_size = 32 8977 n = 20 8978 hidden = 768 8979 num_attention_heads = 12 8980 attention_head_size = hidden // num_attention_heads 8981 8982 def transpose_for_scores(x: torch.Tensor) -> torch.Tensor: 8983 new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) 8984 x = x.view(new_x_shape) 8985 return x.permute(0, 2, 1, 3) 8986 8987 def attention2(key, *, workaround=False, device): 8988 key = transpose_for_scores(key) 8989 res = key.transpose(-1, -2) 8990 return res 8991 8992 A = torch.randn(batch_size, n, hidden) 8993 A_mps = A.detach().clone().to("mps") 8994 8995 r1 = attention2(A, device="cpu") 8996 r2 = attention2(A_mps, device="mps") 8997 8998 r2_cpu = r2.to("cpu") 8999 self.assertEqual(r1, r2_cpu) 9000 9001 def test_group_norm_backward(self, device='mps'): 9002 # See https://github.com/pytorch/pytorch/issues/88331 for more detail 9003 shape = [1, 4, 16, 16] 9004 x = torch.full(shape, 7.0, device=device) 9005 9006 target = torch.ones((1, 3, 128, 128), device=device) 9007 9008 conv_in = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device) 9009 conv_out = nn.Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device) 9010 norm = nn.GroupNorm(32, 128, eps=1e-6, affine=True, device=device) 9011 9012 with torch.enable_grad(): 9013 x = x.detach().requires_grad_() 9014 out = 5.5 * x 9015 out = conv_in(out) 9016 out = out + norm(out) 9017 out = out + norm(out) 9018 out = out + norm(out) 9019 out = F.interpolate(out, scale_factor=8.0, mode="nearest") 9020 out = norm(out) 9021 out = conv_out(out) 9022 9023 loss = (out - target).norm(dim=-1).sum() 9024 grad = -torch.autograd.grad(loss, x)[0] 9025 self.assertFalse(grad.detach().isnan().any().item(), 'NaN gradients returned by autograd') 9026 9027 9028 # def test_conv2d_same_padding(self, device='mps'): 9029 # x = torch.rand(1, 1, 10, 11, device=device) 9030 # y = torch.rand(1, 1, 4, 5, device=device) 9031 # expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :] 9032 # actual = F.conv2d(x, y, padding='same') 9033 # self.assertEqual(expect.to('cpu'), actual.to('cpu')) 9034 9035 # # With dilation 9036 # y = torch.rand(1, 1, 3, 4, device=device) 9037 # expect = F.conv2d(x, y, padding=(2, 3), dilation=2) 9038 # actual = F.conv2d(x, y, padding='same', dilation=2) 9039 # self.assertEqual(expect, actual) 9040 9041 # # Dilation with asymmetric padding 9042 # y = torch.rand(1, 1, 4, 4, device=device) 9043 # expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:] 9044 # actual = F.conv2d(x, y, padding='same', dilation=3) 9045 # self.assertEqual(expect, actual) 9046 9047 9048class TestPad(TestCaseMPS): 9049 def test_constant_pad(self): 9050 m = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5) 9051 input_cpu = torch.randn(1, 16, 16, 16) 9052 input_mps = input_cpu.detach().clone().to("mps") 9053 r_cpu = m(input_cpu) 9054 r_mps = m(input_mps) 9055 
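        # extra sanity check: negative padding crops, so a (-2, -2, -2, -2) pad
        # of a 16x16 input must produce a 12x12 output
        self.assertEqual(r_mps.shape, (1, 16, 12, 12))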
self.assertEqual(r_cpu, r_mps.to("cpu")) 9056 9057 # Arbitrary input dimensions 9058 pad = (1, 1, 0, 0, 0, 0) 9059 value = 3.5 9060 input_cpu = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3)) 9061 input_mps = input_cpu.detach().clone().to("mps") 9062 r_cpu = F.pad(input_cpu, pad=pad, value=value) 9063 r_mps = F.pad(input_mps, pad=pad, value=value) 9064 self.assertEqual(r_cpu, r_mps.to("cpu")) 9065 9066 def test_circular_pad(self): 9067 # https://github.com/pytorch/pytorch/issues/80856 9068 k_cpu = torch.ones(3, 3, 9, 9) 9069 k_mps = k_cpu.detach().clone().to("mps") 9070 9071 x_cpu = torch.rand(1, 3, 32, 32) 9072 x_mps = x_cpu.detach().clone().to("mps") 9073 9074 x_pad_cpu = F.pad(x_cpu, (2, 2, 2, 2), mode='circular') 9075 x_pad_mps = F.pad(x_mps, (2, 2, 2, 2), mode='circular') 9076 9077 y_cpu = F.conv2d(x_pad_cpu, k_cpu) 9078 y_mps = F.conv2d(x_pad_mps, k_mps) 9079 9080 self.assertEqual(y_cpu, y_mps.cpu()) 9081 9082 def test_constant_pad_4d_warning(self): 9083 inputCPU = torch.rand((1, 2, 2, 2, 1, 1)) 9084 inputMPS = inputCPU.detach().clone().to('mps') 9085 outputCPU = F.pad(inputCPU, [0, 0, 0, 0, 0, 0, 1, 0]) 9086 outputMPS = F.pad(inputMPS, [0, 0, 0, 0, 0, 0, 1, 0]) 9087 self.assertEqual(outputCPU, outputMPS) 9088 9089 def test_pad(self): 9090 def helper(shape, padding, op, value=0): 9091 inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) 9092 inputCPU.retain_grad() 9093 inputMPS = inputCPU.detach().clone().to('mps').requires_grad_() 9094 9095 if (op in [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]): 9096 padCriteria = op(padding, value) 9097 else: 9098 padCriteria = op(padding) 9099 outputCPU = padCriteria(inputCPU) 9100 outputMPS = padCriteria(inputMPS) 9101 self.assertEqual(outputCPU, outputMPS) 9102 9103 # backward pass (chose 0.6 just to have the grad_output != 1) 9104 outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6)) 9105 outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6)) 9106 self.assertEqual(inputCPU.grad, inputMPS.grad) 9107 9108 # 1D Padding 9109 helper((2, 4, 3), 2, nn.ReflectionPad1d) 9110 # verify if a change in shape of input would cause problems with graph caching 9111 helper((2, 4, 4), (1, 3), nn.ReflectionPad1d) 9112 # Replication 1D 9113 helper((2, 1, 6), 3, nn.ReplicationPad1d) 9114 # Constant Pad 1D 9115 helper((2, 3, 4), 2, nn.ConstantPad1d) 9116 # Constant Pad 1D with single dimension input 9117 helper((16), (1, 2), nn.ConstantPad1d) 9118 9119 # 2D Padding 9120 helper((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d) 9121 # verify if a change in shape of input would cause problems with graph caching 9122 helper((2, 4, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d) 9123 # this should make the padding (2, 2, 2, 2) 9124 helper((2, 1, 6, 8), 2, nn.ReplicationPad2d) 9125 # verify if a change in shape of padding would cause problems with graph caching 9126 helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d) 9127 # Constant Pad 2D 9128 helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d) 9129 # input size < pad size 9130 helper((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d) 9131 # pad dims < input dims 9132 helper((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d) 9133 # pad dims == input dims 9134 helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d) 9135 # input.numel() == 0 but output.numel() > 0 9136 helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d) 9137 # pad dims < input dims - 2 9138 helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d) 9139 9140 # 3D Padding 9141 helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d) 
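        # (reflection padding requires each pad to be smaller than the matching
        # input dimension; the shapes above are chosen to respect that)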
9142 # verify if a change in shape of padding would cause problems with graph caching 9143 helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d) 9144 # case where input_d == pad_front/back for ReplicationPad3d 9145 helper((3, 4, 5, 6, 7), (1, 2, 3, 4, 5, 6), nn.ReplicationPad3d) 9146 # Constant Pad 3D 9147 helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d) 9148 # input size < pad size 9149 helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d) 9150 # check the workaround for the right padding bug in Monterey 9151 helper((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d) 9152 9153 def test_constant_pad_nd_preserves_memory_format(self): 9154 nchw_tensor = torch.rand((1, 2, 5, 3)) 9155 nchw_padded = torch.constant_pad_nd(nchw_tensor, [1, 2], 0.5) 9156 self.assertTrue(nchw_padded.is_contiguous(memory_format=torch.contiguous_format)) 9157 9158 nhwc_tensor = nchw_tensor.contiguous(memory_format=torch.channels_last) 9159 nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5) 9160 self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last)) 9161 9162 9163class TestLinalgMPS(TestCaseMPS): 9164 def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False): 9165 dtype = t.dtype 9166 numpy_dtype = dtype 9167 alpha = 1.2 if alpha is None else alpha 9168 beta = 0.8 if beta is None else beta 9169 res1 = f(t, m, v, alpha=alpha, beta=beta) 9170 res2 = torch.full_like(res1, math.nan) 9171 if transpose_out: 9172 res2 = res2.t().clone(memory_format=torch.contiguous_format).t() 9173 f(t, m, v, alpha=alpha, beta=beta, out=res2) 9174 res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy()) 9175 if beta != 0: 9176 res3 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy() 9177 res3 = torch.from_numpy(res3).to(dtype) 9178 self.assertEqual(res1, res2) 9179 self.assertEqual(res1, res3) 9180 9181 def test_addmm(self, device="mps", dtype=torch.float32): 9182 M = torch.randn(10, 25, device=device).to(dtype) 9183 m1 = torch.randn(10, 50, device=device).to(dtype) 9184 m2 = torch.randn(50, 25, device=device).to(dtype) 9185 self._test_addmm_addmv(torch.addmm, M, m1, m2) 9186 9187 # Test beta=0, M=nan 9188 M = torch.full((10, 25), math.nan, device=device).to(dtype) 9189 m1 = torch.randn(10, 50, device=device).to(dtype) 9190 m2 = torch.randn(50, 25, device=device).to(dtype) 9191 self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0) 9192 9193 # Test transpose 9194 for t1, t2, t3, t4 in itertools.product([True, False], repeat=4): 9195 def maybe_transpose(cond, m): 9196 if not cond: 9197 return m 9198 return m.t().clone(memory_format=torch.contiguous_format).t() 9199 9200 M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype)) 9201 m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype)) 9202 m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) 9203 self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4) 9204 9205 def _test_addr(self, f, t, m, v, alpha=None, beta=None): 9206 dtype = t.dtype 9207 numpy_dtype = dtype 9208 alpha = 1.2 if alpha is None else alpha 9209 beta = 0.8 if beta is None else beta 9210 res1 = f(t, m, v, alpha=alpha, beta=beta) 9211 res2 = alpha * np.outer(m.to(numpy_dtype).cpu().numpy(), v.to(numpy_dtype).cpu().numpy()) 9212 if beta != 0: 9213 res2 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy() 9214 res2 = torch.from_numpy(res2).to(dtype) 9215 self.assertEqual(res1, res2) 9216 9217 def test_addr(self, device="mps", dtype=torch.float32): 9218 M = 
torch.randn(10, 25, device=device).to(dtype) 9219 m1 = torch.randn(10, device=device).to(dtype) 9220 m2 = torch.randn(25, device=device).to(dtype) 9221 self._test_addr(torch.addr, M, m1, m2) 9222 9223 # Test beta=0, M=nan 9224 M = torch.full((10, 25), math.nan, device=device).to(dtype) 9225 m1 = torch.randn(10, device=device).to(dtype) 9226 m2 = torch.randn(25, device=device).to(dtype) 9227 self._test_addr(torch.addr, M, m1, m2, beta=0) 9228 9229 def test_matrix_rank(self, device="mps", dtype=torch.float32): 9230 matrix_rank = torch.linalg.matrix_rank 9231 9232 def run_test(shape0, shape1, batch): 9233 a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device) 9234 rank_a = matrix_rank(a) 9235 9236 self.assertEqual(rank_a, matrix_rank(a.mH)) 9237 aaH = torch.matmul(a, a.mH) 9238 rank_aaH = matrix_rank(aaH) 9239 rank_aaH_hermitian = matrix_rank(aaH, hermitian=True) 9240 self.assertEqual(rank_aaH, rank_aaH_hermitian) 9241 aHa = torch.matmul(a.mH, a) 9242 self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True)) 9243 9244 # check against NumPy 9245 self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy())) 9246 self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01)) 9247 9248 self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy())) 9249 self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01)) 9250 9251 # hermitian flag for NumPy was added in 1.14.0 9252 if np.lib.NumpyVersion(np.__version__) >= '1.14.0': 9253 self.assertEqual(rank_aaH_hermitian, 9254 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True)) 9255 self.assertEqual(matrix_rank(aaH, 0.01, True), 9256 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True)) 9257 9258 # check out= variant 9259 out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device) 9260 ans = matrix_rank(a, out=out) 9261 self.assertEqual(ans, out) 9262 self.assertEqual(ans, rank_a) 9263 9264 shapes = (3, 13) 9265 batches = ((), (0, ), (4, ), (3, 5, )) 9266 for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches): 9267 # escape only when NotImplementedError of downstream function is raised 9268 # TODO: remove this once the required function is implemented 9269 try: 9270 run_test(shape0, shape1, batch) 9271 except NotImplementedError as e: 9272 with self.assertRaisesRegex( 9273 NotImplementedError, 9274 "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."): 9275 raise e 9276 9277 def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4): 9278 from torch.testing._internal.common_utils import random_hermitian_pd_matrix 9279 9280 def run_test_main(A, hermitian): 9281 # Testing against definition for pseudo-inverses 9282 A_pinv = torch.linalg.pinv(A, hermitian=hermitian) 9283 np_A = A.cpu().numpy() 9284 np_A_pinv = A_pinv.cpu().numpy() 9285 if A.numel() > 0: 9286 self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=precision, rtol=precision) 9287 self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=precision, rtol=precision) 9288 self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), atol=precision, rtol=precision) 9289 self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), atol=precision, rtol=precision) 9290 else: 9291 self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2])) 9292 9293 # Check out= variant 9294 out = torch.empty_like(A_pinv) 9295 ans = torch.linalg.pinv(A, hermitian=hermitian, out=out) 9296 
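            # extra sanity check: the out= variant must write into, and return,
            # the tensor that was passed in rather than allocating a new one
            self.assertEqual(ans.data_ptr(), out.data_ptr())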
self.assertEqual(ans, out) 9297 self.assertEqual(ans, A_pinv) 9298 9299 def run_test_numpy(A, hermitian): 9300 # Check against NumPy output 9301 # Test float rcond, and specific value for each matrix 9302 rconds = [float(torch.rand(1)), ] 9303 # Test different types of rcond tensor 9304 for rcond_type in MPS_DTYPES: 9305 rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type)) 9306 # Test broadcasting of rcond 9307 if A.ndim > 2: 9308 rconds.append(torch.rand(A.shape[-3], device=device)) 9309 for rcond in rconds: 9310 actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian) 9311 torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian) 9312 self.assertEqual(actual, torch_rtol, atol=precision, rtol=precision) 9313 numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy() 9314 expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian) 9315 self.assertEqual(actual, expected, atol=precision, rtol=precision) 9316 9317 for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5), # square matrices 9318 (3, 2), (5, 3, 2), (2, 5, 3, 2), # fat matrices 9319 (2, 3), (5, 2, 3), (2, 5, 2, 3), # thin matrices 9320 (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]: # zero numel matrices 9321 A = torch.randn(*sizes, dtype=dtype, device=device) 9322 hermitian = False 9323 run_test_main(A, hermitian) 9324 run_test_numpy(A, hermitian) 9325 9326 # Check hermitian = True 9327 for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5), # square matrices 9328 (0, 0), (3, 0, 0), ]: # zero numel square matrices 9329 A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device) 9330 hermitian = True 9331 # escape only when NotImplementedError of downstream function is raised 9332 # TODO: remove this once the required function is implemented 9333 try: 9334 run_test_main(A, hermitian) 9335 except NotImplementedError as e: 9336 with self.assertRaisesRegex( 9337 NotImplementedError, 9338 "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."): 9339 raise e 9340 try: 9341 run_test_numpy(A, hermitian) 9342 except NotImplementedError as e: 9343 with self.assertRaisesRegex( 9344 NotImplementedError, 9345 "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."): 9346 raise e 9347 9348 @parametrize("m", [1, 32, 64]) 9349 @parametrize("n", [48, 64]) 9350 @parametrize("q_group", [32, 64, 128, 256]) 9351 @parametrize("num_groups", [1, 2]) 9352 def test__int4_mm(self, m, n, q_group, num_groups): 9353 k = q_group * num_groups 9354 inner_k_tiles = 2 9355 9356 torch.manual_seed(1) 9357 a_f32 = torch.rand((m, k), device="mps") 9358 b_f32 = torch.rand((k, n), device="mps") 9359 9360 def convert_weight_to_int4pack(b): 9361 b_int32, b_scales_and_zeros = _group_quantize_tensor( 9362 b.to("cpu"), n_bit=4, q_group_size=q_group 9363 ) 9364 b_int32 = b_int32.to("mps") 9365 b_scales_and_zeros = b_scales_and_zeros.to("mps") 9366 b_int4pack = torch._convert_weight_to_int4pack( 9367 b_int32, inner_k_tiles 9368 ) 9369 9370 return b_int4pack, b_scales_and_zeros 9371 9372 def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros): 9373 return torch._weight_int4pack_mm( 9374 a, b_int4pack, q_group, b_scales_and_zeros 9375 ) 9376 9377 b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b_f32) 9378 9379 for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []): 9380 a = a_f32.to(dtype=dtype) 9381 b = b_f32.to(dtype=dtype) 9382 
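            # Since k = q_group * num_groups, each of the n output columns of b
            # carries num_groups (scale, zero_point) pairs (e.g. q_group=32,
            # num_groups=2 -> k=64 with two groups per column). The int4 matmul
            # dequantizes on the fly, so the check below bounds the mean
            # relative error instead of requiring exact equality.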
b_scales_and_zeros = b_scales_and_zeros_f32.to(dtype=dtype) 9383 ref = torch.mm(a, b) 9384 res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros) 9385 9386 mean_err = ((res - ref).abs() / ref).mean() 9387 self.assertLess(mean_err, 0.05) 9388 9389 @parametrize("m", [1, 32, 64]) 9390 @parametrize("k", [32, 64]) 9391 @parametrize("n", [32, 64]) 9392 def test__int8_mm(self, m, k, n): 9393 torch.manual_seed(1) 9394 a_f32 = torch.rand((m, k), device="mps") 9395 b_f32 = torch.rand((n, k), device="mps") 9396 9397 def convert_weight_to_int8pack(b): 9398 b_int8pack, b_scales, _ = _dynamically_quantize_per_channel( 9399 b, -128, 127, torch.int8 9400 ) 9401 return b_int8pack, b_scales 9402 9403 def weight_int8pack_mm(a, b_int8pack, b_scales): 9404 return torch._weight_int8pack_mm(a, b_int8pack, b_scales) 9405 9406 b_int8pack, b_scales_f32 = convert_weight_to_int8pack(b_f32) 9407 for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []): 9408 a = a_f32.to(dtype=dtype) 9409 b = b_f32.to(dtype=dtype) 9410 b_scales = b_scales_f32.to(dtype=dtype) 9411 res = weight_int8pack_mm(a, b_int8pack, b_scales) 9412 ref = torch.mm(a, b.transpose(0, 1)) 9413 9414 mean_err = ((res - ref).abs() / ref).mean() 9415 self.assertLess(mean_err, 0.05) 9416 9417 9418class TestSDPA(TestCaseMPS): 9419 def _compare_tensors(self, y, ref): 9420 denom = torch.maximum(ref.abs(), torch.tensor([1e-6], device=ref.device, dtype=ref.dtype)) 9421 err = ((y - ref).abs() / denom).mean().item() 9422 self.assertLess(err, 0.01) 9423 9424 def _test_sdpa_no_mask( 9425 self, 9426 is_causal: bool, 9427 dtype: torch.dtype, 9428 L: int = 1, 9429 S: int = 72, 9430 NH: int = 32, 9431 HS: int = 128, 9432 requires_grad: bool = False 9433 ): 9434 9435 torch.manual_seed(1729) 9436 with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): 9437 q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad) 9438 k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps") 9439 v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps") 9440 q_cpu = q.cpu().detach().cpu().requires_grad_(requires_grad) 9441 k_cpu = k.cpu() 9442 v_cpu = v.cpu() 9443 9444 y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal) 9445 y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=is_causal) 9446 9447 self._compare_tensors(y.cpu(), y_ref) 9448 9449 if requires_grad and torch.is_grad_enabled(): 9450 y.sum().backward() 9451 y_ref.sum().backward() 9452 9453 self._compare_tensors(q.grad.cpu(), q_cpu.grad) 9454 9455 def test_sdpa_no_mask_no_causal_fp32(self): 9456 self._test_sdpa_no_mask(False, torch.float32) 9457 9458 def test_sdpa_no_mask_no_causal_fp16(self): 9459 self._test_sdpa_no_mask(False, torch.float16) 9460 9461 def test_sdpa_no_mask_causal_fp32(self): 9462 self._test_sdpa_no_mask(True, torch.float32) 9463 9464 def test_sdpa_no_mask_causal_fp16(self): 9465 self._test_sdpa_no_mask(True, torch.float16) 9466 9467 def test_sdpa_no_mask_causal_fp16_L7(self): 9468 self._test_sdpa_no_mask(True, torch.float16, 7) 9469 9470 def test_sdpa_no_mask_causal_fp16_L7_S17(self): 9471 self._test_sdpa_no_mask(True, torch.float16, 7, 17) 9472 9473 def test_sdpa_no_mask_causal_fp16_L7_S17_NH23_HS121(self): 9474 self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121) 9475 9476 def test_sdpa_no_mask_no_causal_fp32_grad(self): 9477 self._test_sdpa_no_mask(False, torch.float32, requires_grad=True) 9478 9479 with torch.no_grad(): 9480 
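            # Same call under no_grad: the helper still builds requires_grad
            # tensors, but its backward checks are guarded by
            # torch.is_grad_enabled() and are skipped here.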
self._test_sdpa_no_mask(False, torch.float32, requires_grad=True) 9481 9482 def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128): 9483 torch.manual_seed(1729) 9484 causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps')) 9485 with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): 9486 i = 42 9487 9488 q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps") 9489 k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps") 9490 v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps") 9491 9492 input_pos = torch.tensor([i], dtype=torch.int32, device='mps') 9493 mask = causal_mask[None, None, input_pos] 9494 9495 y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) 9496 y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False) 9497 9498 self._compare_tensors(y.cpu(), y_ref) 9499 9500 def test_sdpa_mask_fp32(self): 9501 self._test_sdpa_mask(torch.float32) 9502 9503 def test_sdpa_mask_fp16(self): 9504 self._test_sdpa_mask(torch.float16) 9505 9506 def test_sdpa_mask_fp16_L6(self): 9507 self._test_sdpa_mask(torch.float16, 6) 9508 9509 def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self): 9510 self._test_sdpa_mask(torch.float16, 7, 17, 23, 121) 9511 9512 9513class TestGatherScatter(TestCaseMPS): 9514 def test_slicing_with_step(self): 9515 # Slicing with step 9516 # https://github.com/pytorch/pytorch/issues/78886 9517 x_mps = torch.zeros(10, dtype=torch.float32, device="mps") 9518 x_mps[::2] = 1.0 9519 9520 x_cpu = torch.zeros(10, dtype=torch.float32, device="cpu") 9521 x_cpu[::2] = 1.0 9522 9523 self.assertEqual(x_cpu, x_mps) 9524 9525 def test_cast_gather_scatter(self): 9526 for _ in range(0, 50): 9527 input = np.random.randint(0, 255, size=(5, 5, 4), dtype=np.uint8) 9528 with torch.no_grad(): 9529 s = torch.tensor(input, dtype=torch.uint8, device="mps").unsqueeze(0) 9530 s_cpu = torch.tensor(input, dtype=torch.uint8, device="cpu").unsqueeze(0) 9531 s = s.long() 9532 s_cpu = s_cpu.long() 9533 self.assertEqual(s.cpu(), s_cpu) 9534 9535 s = s.float() 9536 s_cpu = s_cpu.float() 9537 self.assertEqual(s.cpu(), s_cpu) 9538 9539 s /= 255 9540 s_cpu /= 255 9541 self.assertEqual(s.cpu(), s_cpu) 9542 9543 def test_slicing_replace_column(self): 9544 # https://github.com/pytorch/pytorch/issues/78074 9545 def _helper(tensor_data): 9546 x_cpu = torch.tensor(tensor_data) 9547 x_mps = x_cpu.to('mps') 9548 9549 x_cpu[:, 0] = 7 9550 x_mps[:, 0] = 7 9551 9552 self.assertEqual(x_cpu, x_mps) 9553 9554 _helper([[1, 2, 3], [4, 5, 6]]) 9555 _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 9556 _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) 9557 9558 def test_inplace_scatter(self): 9559 # https://github.com/pytorch/pytorch/issues/79672 9560 a_mps = torch.ones((2, 2),).to(torch.device("mps")) 9561 b_mps = torch.ones((2, 2),).to(torch.device("mps")) 9562 9563 a_cpu = torch.ones((2, 2),).to(torch.device("cpu")) 9564 b_cpu = torch.ones((2, 2),).to(torch.device("cpu")) 9565 9566 a_mps[:, 0] += b_mps[:, 0] 9567 a_cpu[:, 0] += b_cpu[:, 0] 9568 self.assertEqual(a_cpu, a_mps) 9569 9570 a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0] 9571 a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0] 9572 self.assertEqual(a_cpu, a_mps) 9573 9574# These tests were taken from test/test_view_ops.py 9575# They are subset of those tests as currently only this subset is working. 9576# This whole `class` will be removed when we add generic device testing. 
There 9577# are no additional tests added apart from what is part of test_view_ops.py 9578class TestViewOpsMPS(TestCaseMPS): 9579 exact_dtype = True 9580 9581 def test_permute_slicing(self): 9582 # test the fix for crash reported in 9583 # https://github.com/pytorch/pytorch/issues/94190 9584 cpu_x = (torch.randn([3, 2, 2]).float()) 9585 mps_x = cpu_x.detach().clone().to('mps') 9586 cpu_out = cpu_x.permute((2, 0, 1)) * 2.0 9587 mps_out = mps_x.permute((2, 0, 1)) * 2.0 9588 # this print caused a crash prior to fix PR#94259 9589 print(torch.zeros_like(mps_out)) 9590 # test the fix for fill_scalar_mps() mentioned in issue #94190 9591 self.assertEqual(torch.zeros_like(cpu_out), torch.zeros_like(mps_out)) 9592 self.assertEqual(cpu_x[:, 1, :].fill_(1), mps_x[:, 1, :].fill_(1)) 9593 9594 def is_view_of(self, base, other): 9595 if (not other._is_view() or 9596 other is base or 9597 other._base is not base or 9598 base.device != other.device): 9599 return False 9600 # Note: only validates storage on native device types 9601 # because some accelerators, like XLA, do not expose storage 9602 if base.device.type == 'mps': 9603 if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr(): 9604 return False 9605 9606 return True 9607 9608 # Returns true if v1 and v2 are views of the same base 9609 def is_view_of_same_base(self, v1, v2): 9610 if (not v1._is_view() or v1 is v2): 9611 return False 9612 return self.is_view_of(v1._base, v2) 9613 9614 # Performs transpose if contiguous=True, else returns the input tensor as is 9615 def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1): 9616 if contiguous: 9617 return x 9618 else: 9619 return x.transpose(dim0, dim1) 9620 9621 def test_diagonal_view(self, device="mps"): 9622 t = torch.ones((5, 5), device=device) 9623 v = torch.diagonal(t) 9624 self.assertTrue(self.is_view_of(t, v)) 9625 9626 v[0] = 0 9627 self.assertEqual(t[0, 0], v[0]) 9628 9629 t = torch.ones((3, 3, 3), device="mps") 9630 v = torch.diagonal(t, offset=1, dim1=1, dim2=2) 9631 self.assertTrue(self.is_view_of(t, v)) 9632 9633 v[0, 0] = 0 9634 self.assertEqual(t[0, 0, 1], v[0, 0]) 9635 9636 def test_select_view(self, device="mps") -> None: 9637 t = torch.ones((5, 5), device=device) 9638 v = t.select(0, 2) 9639 self.assertTrue(self.is_view_of(t, v)) 9640 9641 v[0] = 0 9642 self.assertEqual(t[2, 0], v[0]) 9643 9644 def test_unbind_view(self, device="mps") -> None: 9645 t = torch.zeros((5, 5), device=device) 9646 tup = torch.unbind(t) 9647 9648 for idx, v in enumerate(tup): 9649 self.assertTrue(self.is_view_of(t, v)) 9650 9651 v[0] = idx + 1 9652 self.assertEqual(t[idx, 0], v[0]) 9653 9654 def test_expand_view(self, device="mps") -> None: 9655 t = torch.ones((5, 1), device=device) 9656 v = t.expand(5, 5) 9657 self.assertTrue(self.is_view_of(t, v)) 9658 9659 v[2, 2] = 0 9660 self.assertEqual(t[2, 0], v[2, 2]) 9661 9662 def test_expand_as_view(self, device="mps"): 9663 t = torch.ones((5, 1), device=device) 9664 e = torch.empty((5, 5), device=device) 9665 v = t.expand_as(e) 9666 self.assertTrue(self.is_view_of(t, v)) 9667 9668 v[2, 2] = 0 9669 self.assertEqual(t[2, 0], v[2, 2]) 9670 9671 def test_narrow_view(self, device="mps"): 9672 t = torch.ones((5, 5), device=device) 9673 v = torch.narrow(t, 1, 2, 2) 9674 self.assertTrue(self.is_view_of(t, v)) 9675 9676 v[0, 0] = 0 9677 self.assertEqual(t[0, 2], v[0, 0]) 9678 9679 def test_permute_view(self, device="mps") -> None: 9680 t = torch.ones((5, 5), device=device) 9681 v = t.permute(1, 0) 9682 self.assertTrue(self.is_view_of(t, v)) 9683 
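        # Writes through the permuted view must land in the shared storage,
        # so v[0, 1] below aliases the transposed coordinate t[1, 0].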
9684 v[0, 1] = 0 9685 self.assertEqual(t[1, 0], v[0, 1]) 9686 9687 def test_transpose_view(self, device="mps"): 9688 for fn in (torch.swapdims, torch.swapaxes, torch.transpose): 9689 t = torch.ones((5, 5), device=device) 9690 v = fn(t, 0, 1) 9691 self.assertTrue(self.is_view_of(t, v)) 9692 9693 v[0, 1] = 0 9694 self.assertEqual(t[1, 0], v[0, 1]) 9695 9696 def test_transpose_inplace_view(self, device="mps"): 9697 t = torch.ones(5, 5, device=device) 9698 v = t.view_as(t) 9699 v = v.swapdims_(0, 1) 9700 self.assertTrue(self.is_view_of(t, v)) 9701 v[0, 1] = 0 9702 self.assertEqual(t[1, 0], v[0, 1]) 9703 9704 t = torch.ones(5, 5, device=device) 9705 v = t.view_as(t) 9706 v = v.swapaxes_(0, 1) 9707 self.assertTrue(self.is_view_of(t, v)) 9708 v[0, 1] = 0 9709 self.assertEqual(t[1, 0], v[0, 1]) 9710 9711 t = torch.ones(5, 5, device=device) 9712 v = t.view_as(t) 9713 v = v.transpose_(0, 1) 9714 self.assertTrue(self.is_view_of(t, v)) 9715 v[0, 1] = 0 9716 self.assertEqual(t[1, 0], v[0, 1]) 9717 9718 def test_t_view(self, device="mps"): 9719 t = torch.ones((5, 5), device=device) 9720 v = t.t() 9721 self.assertTrue(self.is_view_of(t, v)) 9722 9723 v[0, 1] = 0 9724 self.assertEqual(t[1, 0], v[0, 1]) 9725 9726 def test_inplace_view_add(self): 9727 # https://github.com/pytorch/pytorch/issues/96153 9728 t_mps = torch.ones((2, 6,), device='mps')[1].reshape(2, 3) 9729 t_cpu = torch.ones((2, 6,), device='cpu')[1].reshape(2, 3) 9730 t_mps = t_mps + 1 9731 t_cpu = t_cpu + 1 9732 self.assertEqual(t_mps, t_cpu) 9733 9734 def test_t_inplace_view(self, device="mps"): 9735 t = torch.ones(5, 5, device=device) 9736 v = t.view_as(t) 9737 v = v.t_() 9738 self.assertTrue(self.is_view_of(t, v)) 9739 v[0, 1] = 0 9740 self.assertEqual(t[1, 0], v[0, 1]) 9741 9742 def test_T_view(self, device="mps"): 9743 for op in ("T", "H", "mT", "mH"): 9744 t = torch.ones((5, 5), device=device) 9745 v = getattr(t, op) 9746 self.assertTrue(self.is_view_of(t, v)) 9747 9748 v[0, 1] = 0 9749 self.assertEqual(t[1, 0], v[0, 1]) 9750 9751 def test_unfold_view(self, device="mps"): 9752 t = torch.ones(10, device=device) 9753 v = t.unfold(0, 3, 2) 9754 self.assertTrue(self.is_view_of(t, v)) 9755 9756 v[1, 0] = 0 9757 self.assertEqual(t[2], v[1, 0]) 9758 9759 def test_squeeze_view(self, device="mps"): 9760 t = torch.ones(5, 1, 5, device=device) 9761 v = torch.squeeze(t) 9762 self.assertTrue(self.is_view_of(t, v)) 9763 v[0, 1] = 0 9764 self.assertIs(t, v._base) 9765 9766 def test_squeeze_inplace_view(self, device="mps"): 9767 t = torch.ones(5, 5, device=device) 9768 v = t.view_as(t) 9769 v = v.squeeze_() 9770 self.assertTrue(self.is_view_of(t, v)) 9771 v[0, 1] = 0 9772 self.assertIs(t, v._base) 9773 9774 def test_unsqueeze_view(self, device="mps"): 9775 t = torch.ones(5, 5, device=device) 9776 v = torch.unsqueeze(t, 1) 9777 self.assertTrue(self.is_view_of(t, v)) 9778 9779 v[0, 0, 1] = 0 9780 self.assertEqual(t[0, 1], v[0, 0, 1]) 9781 9782 def test_unsqueeze_inplace_view(self, device="mps"): 9783 t = torch.ones(5, 5, device=device) 9784 v = t.view_as(t) 9785 v = v.unsqueeze_(1) 9786 self.assertTrue(self.is_view_of(t, v)) 9787 v[0, 0, 1] = 0 9788 self.assertEqual(t[0, 1], v[0, 0, 1]) 9789 9790 def test_as_strided_view(self, device="mps"): 9791 t = torch.ones(5, 5, device=device) 9792 v = torch.as_strided(t, (25,), (1,)) 9793 self.assertTrue(self.is_view_of(t, v)) 9794 9795 v[6] = 0 9796 self.assertEqual(t[1, 1], v[6]) 9797 9798 def test_as_strided_inplace_view(self, device="mps"): 9799 t = torch.ones(5, 5, device=device) 9800 v = t.view_as(t) 9801 v = 
v.as_strided_((25,), (1,)) 9802 self.assertTrue(self.is_view_of(t, v)) 9803 v[6] = 0 9804 self.assertEqual(t[1, 1], v[6]) 9805 9806 def test_view_view(self, device="mps"): 9807 t = torch.ones(5, 5, device=device) 9808 v = t.view(25) 9809 self.assertTrue(self.is_view_of(t, v)) 9810 9811 v[6] = 0 9812 self.assertEqual(t[1, 1], v[6]) 9813 9814 def test_view_as_view(self, device="mps"): 9815 t = torch.ones(5, 5, device=device) 9816 e = torch.empty((25,)) 9817 v = t.view_as(e) 9818 self.assertTrue(self.is_view_of(t, v)) 9819 9820 v[6] = 0 9821 self.assertEqual(t[1, 1], v[6]) 9822 9823 def test_contiguous_self(self, device="mps"): 9824 t = torch.ones(5, 5, device=device) 9825 s = t.contiguous() 9826 self.assertIs(s, t) 9827 9828 def test_contiguous_nonview(self, device="mps"): 9829 t = torch.ones(5, 5, device=device) 9830 nv = t.t().contiguous() 9831 self.assertFalse(self.is_view_of(t, nv)) 9832 9833 nv[0, 0] = 0 9834 self.assertNotEqual(t[0, 0], nv[0, 0]) 9835 9836 def test_reshape_view(self, device="mps"): 9837 t = torch.ones(5, 5, device=device) 9838 v = torch.reshape(t, (25,)) 9839 self.assertTrue(self.is_view_of(t, v)) 9840 9841 v[6] = 0 9842 self.assertEqual(t[1, 1], v[6]) 9843 9844 def test_reshape_as_view(self, device="mps"): 9845 t = torch.ones(5, 5, device=device) 9846 e = torch.empty((25,), device=device) 9847 v = t.reshape_as(e) 9848 self.assertTrue(self.is_view_of(t, v)) 9849 9850 v[6] = 0 9851 self.assertEqual(t[1, 1], v[6]) 9852 9853 def test_reshape_nonview(self, device="mps"): 9854 t = torch.ones(5, 5, device=device) 9855 nv = torch.reshape(t.t(), (25,)) 9856 self.assertFalse(self.is_view_of(t, nv)) 9857 9858 nv[6] = 0 9859 self.assertNotEqual(t[1, 1], nv[6]) 9860 9861 def test_flatten_view(self, device="mps"): 9862 def test_writes_propagate(t, v): 9863 idx_t = (0,) * t.ndim 9864 idx_v = (0,) * v.ndim 9865 v[idx_v] = 0 9866 self.assertEqual(t[idx_t], v[idx_v]) 9867 9868 t = torch.ones(1, 2, 3, 4, device=device) 9869 v = t.flatten() 9870 self.assertTrue(self.is_view_of(t, v)) 9871 test_writes_propagate(t, v) 9872 9873 # zero-dimensional tensor 9874 t = torch.tensor(1, device=device) 9875 v = t.flatten() 9876 test_writes_propagate(t, v) 9877 self.assertTrue(self.is_view_of(t, v)) 9878 9879 t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3) 9880 v = t.flatten(0, 1) 9881 test_writes_propagate(t, v) 9882 self.assertTrue(self.is_view_of_same_base(t, v)) 9883 9884 # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups: 9885 t = torch.ones(720, device=device) \ 9886 .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0)) 9887 # [--1--|---2---|-3-] [--1--|----2---|-3-] 9888 v1 = t.flatten(0, 1) 9889 v2 = v1.flatten(1, 3) 9890 v3 = v2.flatten(2, 2) 9891 test_writes_propagate(t, v1) 9892 self.assertTrue(self.is_view_of_same_base(t, v1)) 9893 test_writes_propagate(t, v2) 9894 self.assertTrue(self.is_view_of_same_base(t, v2)) 9895 test_writes_propagate(t, v3) 9896 self.assertTrue(self.is_view_of_same_base(t, v3)) 9897 9898 def test_flatten_nonview(self, device="mps"): 9899 def assert_is_nonview(t, nv): 9900 idx_t = (0,) * t.ndim 9901 idx_nv = (0,) * nv.ndim 9902 self.assertFalse(nv._is_view()) 9903 nv[idx_nv] = 0 9904 self.assertNotEqual(t[idx_t], nv[idx_nv]) 9905 t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3) 9906 nv = t.flatten(1, 3) 9907 assert_is_nonview(t, nv) 9908 9909 t = torch.ones(2, 2, device=device).T 9910 nv = t.flatten() 9911 assert_is_nonview(t, nv) 9912 9913 # flatten returns the original object if start_dim=end_dim 9914 t = torch.ones(2, 2,
device=device) 9915 nv = t.flatten(1, 1) 9916 self.assertIs(t, nv) 9917 9918 def test_basic_indexing_slice_view(self, device="mps"): 9919 t = torch.ones(5, 5, device=device) 9920 v = t[:2, :3] 9921 self.assertTrue(self.is_view_of(t, v)) 9922 9923 v[0, 0] = 0 9924 self.assertEqual(t[0, 0], v[0, 0]) 9925 9926 def test_basic_indexing_ellipses_view(self, device="mps"): 9927 t = torch.ones(5, 5, device=device) 9928 v = t[..., :2] 9929 self.assertTrue(self.is_view_of(t, v)) 9930 9931 v[0, 0] = 0 9932 self.assertEqual(t[0, 0], v[0, 0]) 9933 9934 def test_basic_indexing_newaxis_view(self, device="mps"): 9935 t = torch.ones(5, 5, device=device) 9936 v = t[None, :2, 3] 9937 self.assertTrue(self.is_view_of(t, v)) 9938 9939 v[0, 0] = 0 9940 self.assertEqual(t[0, 3], v[0, 0]) 9941 9942 def test_chunk_view(self, device="mps"): 9943 t = torch.zeros(3, 3, device=device) 9944 l = torch.chunk(t, 3) 9945 9946 for idx, v in enumerate(l): 9947 self.assertTrue(self.is_view_of(t, v)) 9948 9949 v[0, 0] = idx + 1 9950 self.assertEqual(t[idx, 0], v[0, 0]) 9951 9952 def test_split_view(self, device="mps"): 9953 t = torch.zeros(3, 3, device=device) 9954 l = torch.split(t, [1, 1, 1]) 9955 9956 for idx, v in enumerate(l): 9957 self.assertTrue(self.is_view_of(t, v)) 9958 9959 v[0, 0] = idx + 1 9960 self.assertEqual(t[idx, 0], v[0, 0]) 9961 9962 def test_movedim_view(self, device="mps"): 9963 def run_test(device, op): 9964 t = torch.zeros(3, 3, device=device) 9965 out = op(t) 9966 9967 self.assertTrue(self.is_view_of(t, out)) 9968 9969 # Randomly change values in output 9970 # and verify that original is changed 9971 # as well. 9972 for _ in range(3): 9973 idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2) 9974 out[idx_1, idx_2] = random.random() 9975 self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2]) 9976 9977 for fn in [torch.movedim, torch.moveaxis]: 9978 op = partial(fn, source=(0, 1), destination=(1, 0)) 9979 run_test(device, op) 9980 9981 op = partial(fn, source=0, destination=1) 9982 run_test(device, op) 9983 9984 # Testing that the generated view_copy kernel and its derivative are implemented correctly 9985 def test_view_copy(self, device="mps"): 9986 a = torch.randn(4, device=device, requires_grad=True) 9987 a_ref = a.clone().detach().requires_grad_() 9988 a_view = a_ref.view(2, 2) 9989 a_view_copy = torch.view_copy(a, (2, 2)) 9990 9991 # view_copy ops don't preserve view relationship 9992 self.assertTrue(self.is_view_of(a_ref, a_view)) 9993 self.assertFalse(self.is_view_of(a, a_view_copy)) 9994 9995 a_view_copy.sum().backward() 9996 a_view.sum().backward() 9997 9998 # forward and backward give the same shape + result 9999 self.assertEqual(a_view_copy, a_view) 10000 self.assertEqual(a.grad, a_ref.grad) 10001 10002 def test_view_copy_out(self, device="mps"): 10003 a = torch.randn(2, 2, device=device) 10004 out = torch.empty(2, device=device) 10005 10006 torch.diagonal_copy(a, out=out) 10007 expected = torch.diagonal_copy(a) 10008 10009 self.assertEqual(expected, out) 10010 10011 a = torch.randn(4, device=device) 10012 out1 = torch.empty(2, device=device) 10013 out2 = torch.empty(2, device=device) 10014 10015 torch.split_copy(a, 2, out=(out1, out2)) 10016 expected1, expected2 = torch.split_copy(a, 2) 10017 10018 self.assertEqual(expected1, out1) 10019 self.assertEqual(expected2, out2) 10020 10021 def test_detached_view_copy(self, device="mps"): 10022 # https://github.com/pytorch/pytorch/issues/86052 10023 x = torch.arange(2) 10024 # .detach() makes y not a view, but contig tensor 10025 # with non-zero 
offset 10026 y = x[1].detach() 10027 z = y.to(device) 10028 self.assertEqual(y, z.cpu()) 10029 10030 def test_empty_reshape(self, device="mps"): 10031 x = torch.randn(0, 6, device=device) 10032 self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape) 10033 # should be viewable -- i.e. data_ptr is the same. 10034 self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr()) 10035 10036 # match NumPy semantics -- don't infer the size of dimension with a degree of freedom 10037 self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) 10038 10039 def test_expand(self, device="mps"): 10040 tensor = torch.rand(1, 8, 1, device=device) 10041 tensor2 = torch.rand(5, device=device) 10042 template = torch.rand(4, 8, 5, device=device) 10043 target = template.size() 10044 self.assertEqual(tensor.expand_as(template).size(), target) 10045 self.assertEqual(tensor.expand(4, 8, 5).size(), target) 10046 self.assertEqual(tensor.expand(target).size(), target) 10047 self.assertEqual(tensor2.expand_as(template).size(), target) 10048 self.assertEqual(tensor2.expand(4, 8, 5).size(), target) 10049 self.assertEqual(tensor2.expand(target).size(), target) 10050 10051 # test double expand 10052 self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1)) 10053 10054 # test non-contiguous 10055 noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0] 10056 self.assertFalse(noncontig.is_contiguous()) 10057 self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1)) 10058 10059 # make sure it's compatible with unsqueeze 10060 expanded = tensor2.expand(1, 1, 5) 10061 unsqueezed = tensor2.unsqueeze(0).unsqueeze(1) 10062 self.assertEqual(expanded, unsqueezed) 10063 self.assertEqual(expanded.stride(), unsqueezed.stride()) 10064 10065 # test -1 as target size 10066 self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5)) 10067 self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1)) 10068 10069 # test expanding empty to empty 10070 self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device)) 10071 10072 def test_view_empty(self, device="mps"): 10073 x = torch.randn(0, 6, device=device) 10074 self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape) 10075 10076 def test_reshape(self, device="mps"): 10077 x = torch.randn(3, 3, device=device) 10078 self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr()) 10079 self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) 10080 self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) 10081 self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) 10082 10083 y = torch.randn(4, 4, 4, device=device)[:, 0, :] 10084 # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape 10085 if device != "meta": 10086 self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) 10087 self.assertEqual(y.contiguous().view(-1), y.reshape(-1)) 10088 self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr()) 10089 10090 s = torch.randn((), device=device) 10091 self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr()) 10092 self.assertEqual(s.reshape(-1).shape, (1,)) 10093 self.assertRaises(RuntimeError, lambda: s.reshape(2)) 10094 10095 empty = torch.tensor([], device=device) 10096 self.assertEqual(empty, empty.reshape(-1)) 10097 self.assertEqual(empty, empty.reshape([0])) 10098 # TODO: fix these once we have multi-dimensional empty tensors 10099 self.assertEqual(empty.reshape([0, 1]).shape, (0, 1)) 10100 self.assertEqual(empty.reshape([1, -1]).shape, (1, 0)) 
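        # A zero-element tensor may take any shape whose numel is 0, but
        # reshaping it to a positive numel (reshape(1) below) has no data
        # to fill the result and must raise.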
10101 self.assertRaises(RuntimeError, lambda: empty.reshape(1)) 10102 10103 x = torch.randn(3, 3, device=device) 10104 self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr()) 10105 self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr()) 10106 self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device))) 10107 10108 def test_narrow(self, device="mps"): 10109 x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) 10110 self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]])) 10111 self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]])) 10112 self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]])) 10113 self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]])) 10114 self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]])) 10115 self.assertEqual(x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])) 10116 self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]])) 10117 self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]])) 10118 10119 def test_narrow_tensor(self, device="mps"): 10120 x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) 10121 self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]])) 10122 with self.assertRaises(Exception): 10123 x.narrow(0, torch.tensor(0.), 1) 10124 with self.assertRaises(Exception): 10125 x.narrow(0, torch.tensor([0]), 1) 10126 with self.assertRaises(Exception): 10127 x.narrow(0, torch.tensor([0, 1]), 1) 10128 10129 def test_t(self, device="mps"): 10130 # Test 0D tensors 10131 x = torch.randn(()) 10132 self.assertEqual(x, x.t()) 10133 x = x.to_sparse() 10134 self.assertEqual(x, x.t()) 10135 10136 # Test 1D tensors 10137 x = torch.arange(4) 10138 self.assertEqual(x, x.t()) 10139 x = x.to_sparse() 10140 self.assertEqual(x, x.t()) 10141 10142 # Test 2D tensors 10143 x = torch.rand((2, 2)) 10144 self.assertEqual(x.t(), x.transpose(0, 1)) 10145 x = x.to_sparse() 10146 self.assertEqual(x.t(), x.transpose(0, 1)) 10147 10148 # Test 3D tensor 10149 x = torch.rand((2, 2, 2)) 10150 with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'): 10151 x.t() 10152 x = x.to_sparse() 10153 with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'): 10154 x.t() 10155 10156 def test_split(self, device="mps"): 10157 tensor = torch.rand(7, 4) 10158 split_size = 3 10159 dim = 0 10160 target_sizes = ([3, 4], [3, 4], [1, 4]) 10161 splits = tensor.split(split_size, dim) 10162 start = 0 10163 for target_size, split in zip(target_sizes, splits): 10164 self.assertEqual(split.size(), target_size) 10165 self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) 10166 start = start + target_size[dim] 10167 10168 # Variable sections split 10169 tensor = torch.randn(20, 10) 10170 dim = 0 10171 split_sizes = [5, 5, 10] 10172 target_sizes = ([[5, 10], [5, 10], [10, 10]]) 10173 splits = tensor.split(split_sizes, dim) 10174 start = 0 10175 for target_size, split in zip(target_sizes, splits): 10176 self.assertEqual(split.size(), target_size) 10177 self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) 10178 start = start + target_size[dim] 10179 10180 split_sizes = [2, 2, 6] 10181 target_sizes = ([20, 2], [20, 2], [20, 6]) 10182 dim = 1 10183 splits = tensor.split(split_sizes, dim) 10184 start = 0 10185 for target_size, split in zip(target_sizes, splits): 10186 self.assertEqual(split.size(), 
target_size) 10187 self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) 10188 start = start + target_size[dim] 10189 10190 def test_chunk(self, device="mps"): 10191 tensor = torch.rand(4, 7) 10192 num_chunks = 3 10193 dim = 1 10194 target_sizes = ([4, 3], [4, 3], [4, 1]) 10195 splits = tensor.chunk(num_chunks, dim) 10196 start = 0 10197 for target_size, split in zip(target_sizes, splits): 10198 self.assertEqual(split.size(), target_size) 10199 self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 10200 atol=0, rtol=0) 10201 start = start + target_size[dim] 10202 10203 # Invalid chunk sizes 10204 error_regex = 'chunk expects.*greater than 0' 10205 with self.assertRaisesRegex(RuntimeError, error_regex): 10206 tensor.chunk(0) 10207 with self.assertRaisesRegex(RuntimeError, error_regex): 10208 tensor.chunk(-2) 10209 10210 def test_unsqueeze(self, device="mps") -> None: 10211 x = torch.randn(2, 3, 4) 10212 y = x.unsqueeze(1) 10213 self.assertEqual(y, x.view(2, 1, 3, 4)) 10214 y = x.clone().unsqueeze_(2) 10215 self.assertEqual(y, x.view(2, 3, 1, 4)) 10216 10217 x = x[:, 1] 10218 self.assertFalse(x.is_contiguous()) 10219 y = x.unsqueeze(1) 10220 self.assertEqual(y, x.contiguous().view(2, 1, 4)) 10221 y = x.clone().unsqueeze_(2) 10222 self.assertEqual(y, x.contiguous().view(2, 4, 1)) 10223 10224 # unit test for special case transposed copy (see ATen/native/Copy.cpp for details) 10225 def test_big_transpose(self, device="mps"): 10226 t = torch.rand(456, 789, device=device) 10227 t1 = t.t().contiguous() 10228 t2 = torch.from_numpy(t.cpu().numpy().transpose()) 10229 self.assertEqual(t1, t2) 10230 10231 def test_T(self, device="mps"): 10232 a = torch.randn(2, 3, 4, device=device) 10233 t1 = a.T 10234 t2 = a.permute(2, 1, 0) 10235 self.assertEqual(t2, t1) 10236 b = torch.randn(10, device=device) 10237 self.assertEqual(b, b.T) 10238 10239 def test_transposes(self, device="mps", dtype=torch.float32): 10240 for op in ("T", "H", "mT", "mH", "adjoint"): 10241 shapes = ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),) 10242 for shape in shapes: 10243 a = make_tensor(shape, device=device, dtype=dtype) 10244 t1 = getattr(a, op) 10245 if op == "adjoint": 10246 t1 = t1() 10247 t2 = a 10248 if a.ndim != 0: 10249 t2 = t2.transpose(-2, -1) 10250 if op[-1] == "H" or op == "adjoint": 10251 t2 = t2.conj() 10252 self.assertEqual(t2, t1) 10253 10254 def test_transposes_errors(self, device="mps", dtype=torch.float32): 10255 for op in ("H", "mT", "mH", "adjoint"): 10256 shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),) 10257 for shape in shapes: 10258 a = make_tensor(shape, device=device, dtype=dtype) 10259 with self.assertRaisesRegex(RuntimeError, "only supported on matrices"): 10260 t1 = getattr(a, op) 10261 if op == "adjoint": 10262 t1 = t1() 10263 10264 def test_python_types(self, device="mps"): 10265 a1 = torch.randn((1, 2), device=device, dtype=torch.float32) 10266 a2 = torch.randn((1, 2), device=device, dtype=torch.float32) 10267 self.assertEqual(a1.dtype, a2.dtype) 10268 10269 b1 = torch.arange(10, 20, dtype=torch.int64, device=device) 10270 b2 = torch.arange(10, 20, dtype=int, device=device) 10271 self.assertEqual(b1.dtype, b2.dtype) 10272 10273 c1 = torch.tensor([True, False], dtype=torch.bool, device=device) 10274 c2 = torch.tensor([True, False], dtype=bool, device=device) 10275 self.assertEqual(c1.dtype, c2.dtype) 10276 10277 # TODO: is resize best put in test_view_ops? 
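    # A minimal sketch (illustrative only, hence no `test_` prefix) of the
    # stride-preservation contract exercised by the next test: resizing a
    # tensor to its own shape must keep the transposed, non-contiguous
    # strides rather than reallocating a contiguous buffer.
    def _resize_stride_sketch(self, device="mps"):
        x = torch.empty(2, 3, device=device).t()  # shape (3, 2), strides (1, 3)
        x.resize_as_(x)                           # same numel: layout is untouched
        return x.stride()                         # -> (1, 3)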
10278 def test_resize_as_preserves_strides(self, device="mps"): 10279 x = torch.empty(2, 3).t() 10280 old_strides = x.stride() 10281 x.resize_as_(x) 10282 self.assertEqual(x.stride(), old_strides) 10283 10284 def test_memory_format_resize_as(self, device="mps"): 10285 def test_helper(shape, memory_format, device="mps"): 10286 xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format) 10287 flat = torch.randn(xc.numel(), device=device) 10288 flat.resize_as_(xc, memory_format=torch.preserve_format) 10289 self.assertTrue(flat.is_contiguous(memory_format=memory_format)) 10290 10291 test_helper((10, 3, 32, 32), torch.channels_last, device="mps") 10292 test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device="mps") 10293 10294 def test_memory_format_resize_(self, device="mps"): 10295 def test_helper(shape, numel, memory_format, device="mps"): 10296 flat = torch.randn(numel, device=device) 10297 flat.resize_(shape, memory_format=memory_format) 10298 self.assertTrue(flat.is_contiguous(memory_format=memory_format)) 10299 10300 test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device="mps") 10301 test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device="mps") 10302 10303 # TODO: OpInfo this 10304 def _test_atleast(self, device, torch_fn): 10305 # 0-dim 10306 s = torch.tensor(0.5, dtype=torch.double, requires_grad=True) 10307 10308 gradcheck(lambda x: torch_fn(x), s) 10309 gradgradcheck(lambda x: torch_fn(x), s) 10310 10311 # 1-dim 10312 a = torch.rand(4, dtype=torch.double, requires_grad=True) 10313 10314 gradcheck(lambda x: torch_fn(x), a) 10315 gradgradcheck(lambda x: torch_fn(x), a) 10316 10317 # 2,3,4-dim 10318 b = torch.rand(4, 3, dtype=torch.double, requires_grad=True) 10319 c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True) 10320 d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True) 10321 10322 input_tuple = (s, a, b, c, d) 10323 gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple) 10324 gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple) 10325 10326 def test_atleast_gradient(self, device="mps"): 10327 self._test_atleast(device, torch.atleast_1d) 10328 self._test_atleast(device, torch.atleast_2d) 10329 self._test_atleast(device, torch.atleast_3d) 10330 10331 def test_view(self, device="mps"): 10332 tensor = torch.rand(15, device=device) 10333 template = torch.rand(3, 5, device=device) 10334 empty = torch.empty(0, device=device) 10335 target = template.size() 10336 self.assertEqual(tensor.view_as(template).size(), target) 10337 self.assertEqual(tensor.view(3, 5).size(), target) 10338 self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target) 10339 self.assertEqual(tensor.view(-1, 5).size(), target) 10340 self.assertEqual(tensor.view(3, -1).size(), target) 10341 tensor_view = tensor.view(5, 3) 10342 tensor_view.fill_(random.uniform(0, 1)) 10343 self.assertEqual(empty.view_as(empty), empty) 10344 self.assertEqual(empty.view(0), empty) 10345 self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1])) 10346 self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty) 10347 10348 # test size inference with empty tensors 10349 self.assertEqual(empty.view(-1).size(), torch.Size([0])) 10350 self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0])) 10351 10352 with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): 10353 empty.view(-1, 0) 10354 10355 with self.assertRaisesRegex(RuntimeError, 
r"because the unspecified dimension size -1 can be any value"): 10356 empty.view(3, 0, -1, 0) 10357 10358 self.assertRaises(RuntimeError, lambda: tensor.view(15, 0)) 10359 self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) 10360 self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) 10361 10362 def test_contiguous(self, device="mps"): 10363 x = torch.randn(1, 16, 5, 5, device=device) 10364 self.assertTrue(x.is_contiguous()) 10365 stride = list(x.stride()) 10366 stride[0] = 20 10367 # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 10368 x.set_(x.storage(), 0, x.size(), stride) 10369 self.assertTrue(x.is_contiguous()) 10370 10371 def test_resize_mps_dtypes(self, device="mps"): 10372 shape = (2, 2) 10373 for dt in MPS_DTYPES: 10374 x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) 10375 x.resize_(shape) 10376 self.assertEqual(shape, x.shape) 10377 10378 def test_resize_as_mps_dtypes(self, device="mps"): 10379 for dt in MPS_DTYPES: 10380 x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) 10381 y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) 10382 x.resize_as_(y) 10383 self.assertEqual(y.shape, x.shape) 10384 10385 def test_resize_overflow(self, device="mps"): 10386 x = torch.empty((), dtype=torch.float64) 10387 with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'): 10388 x.resize_([2, 4, 2**29, 2**29]) 10389 with self.assertRaisesRegex(RuntimeError, 'overflow'): 10390 x.resize_([8, 8, 2**29, 2**29]) 10391 10392 def test_view_all_dtypes_and_devices(self, device="mps"): 10393 for dt in (torch.float, torch.bool): 10394 x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) 10395 self.assertEqual(x.view(6).shape, [6]) 10396 10397class TestConvolutionMPS(TestCaseMPS): 10398 def test_conv1d_all_strides_paddings(self): 10399 # https://github.com/pytorch/pytorch/issues/82921 10400 def helper(stride, padding): 10401 y_cpu = torch.randn(1, 57, 40) 10402 conv_cpu = nn.Conv1d(57, 20, stride=stride, padding=padding, kernel_size=3, bias=False) 10403 conv_gpu = copy.deepcopy(conv_cpu).to(device='mps') 10404 x_cpu = conv_cpu(y_cpu) 10405 10406 y_gpu = y_cpu.to(device='mps') 10407 x_gpu = conv_gpu(y_gpu) 10408 self.assertEqual(x_cpu, x_gpu.cpu()) 10409 for stride in range(1, 4): 10410 for padding in range(1, 4): 10411 helper(stride, padding) 10412 10413 10414 def test_conv1d_channels_last(self): 10415 # https://github.com/pytorch/pytorch/issues/81557 10416 model_cpu = torch.nn.Conv1d(1, 128, 3) 10417 a_cpu = torch.arange((128 * 176), dtype=torch.float32) 10418 a_cpu = a_cpu.view(128, 176, 1).permute(0, 2, 1) 10419 out_cpu = model_cpu(a_cpu) 10420 10421 a_mps = a_cpu.detach().clone().to("mps") 10422 model_mps = model_cpu.to("mps") 10423 out_mps = model_mps(a_mps) 10424 10425 self.assertEqual(out_cpu, out_mps.cpu(), rtol=2.6e-05, atol=2e-04) 10426 10427 def test_conv_transpose_1d_all_strides(self): 10428 # https://github.com/pytorch/pytorch/issues/82711 10429 def helper(stride): 10430 y_cpu = torch.ones(1, 1, 2) 10431 deconv_cpu = nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=1, stride=stride, bias=False, padding=1) 10432 deconv_cpu.weight.data = torch.ones(1, 1, 2) 10433 deconv_gpu = copy.deepcopy(deconv_cpu).to(device='mps') 10434 x_cpu = deconv_cpu(y_cpu) 10435 10436 y_gpu = y_cpu.to(device='mps') 10437 x_gpu = deconv_gpu(y_gpu) 10438 self.assertEqual(x_cpu, x_gpu.cpu()) 10439 [helper(stride) for stride in [1, 2, 3]] 10440 10441 def 
test_conv_transpose_1d_nn_functional(self): 10442 # https://github.com/pytorch/pytorch/issues/82563 10443 tin = torch.rand((1, 512, 1245), dtype=torch.float32) 10444 tparams = torch.rand((512, 256, 16), dtype=torch.float32) 10445 tbias = torch.rand((256), dtype=torch.float32) 10446 10447 device = 'cpu' 10448 tcpu = torch.nn.functional.conv_transpose1d(tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4) 10449 10450 device = 'mps' 10451 tgpu = torch.nn.functional.conv_transpose1d(tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4) 10452 10453 self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04) 10454 10455 def test_conv_backward_1d_channels_last(self): 10456 def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1): 10457 # https://github.com/pytorch/pytorch/issues/84511 10458 conv_cpu = torch.nn.Conv1d( 10459 in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_() 10460 conv_mps = torch.nn.Conv1d( 10461 in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps") 10462 conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True) 10463 conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True) 10464 10465 10466 data = torch.rand(shape, dtype=torch.float32) 10467 x_cpu = data.permute(0, 2, 1).contiguous().requires_grad_(True) 10468 x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True) 10469 res_cpu = conv_cpu(x_cpu) 10470 res_mps = conv_mps(x_mps) 10471 self.assertEqual(res_cpu, res_mps) 10472 res_cpu = res_cpu.sum().backward() 10473 res_mps = res_mps.sum().backward() 10474 10475 self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04) 10476 self.assertEqual(x_cpu.grad, x_mps.grad) 10477 10478 helper(shape=(1, 176, 1)) 10479 helper(shape=(2, 12, 1)) 10480 helper(shape=(3, 176, 1)) 10481 helper(shape=(4, 376, 1)) 10482 helper(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1) 10483 helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3) 10484 10485 def test_conv1d_contiguous(self): 10486 model_cpu = torch.nn.Conv1d(1, 128, 3) 10487 a_cpu = torch.ones(128, 1, 176) 10488 out_cpu = model_cpu(a_cpu) 10489 10490 a_mps = a_cpu.detach().clone().to("mps") 10491 model_mps = model_cpu.to("mps") 10492 out_mps = model_mps(a_mps) 10493 10494 self.assertEqual(out_cpu.shape, out_mps.shape) 10495 self.assertEqual(out_cpu, out_mps.cpu()) 10496 10497 def test_conv2d_all_strides_paddings(self): 10498 # https://github.com/pytorch/pytorch/issues/83180 10499 def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data): 10500 x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_() 10501 x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_() 10502 10503 if permute_data: 10504 x_cpu.permute(0, 2, 3, 1) 10505 x_mps.permute(0, 2, 3, 1) 10506 10507 for strideX in range(1, 4): 10508 for strideY in range(1, 4): 10509 conv_cpu = torch.nn.Conv2d( 10510 in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_() 10511 conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_() 10512 10513 conv_mps = torch.nn.Conv2d( 10514 in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps") 10515 conv_mps.weight.data = 
conv_cpu.weight.data.detach().clone().to("mps").requires_grad_() 10516 conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_() 10517 10518 res_cpu = conv_cpu(x_cpu) 10519 res_mps = conv_mps(x_mps) 10520 self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05) 10521 res_cpu = res_cpu.sum().backward() 10522 res_mps = res_mps.sum().backward() 10523 self.assertEqual(res_cpu, res_mps, rtol=2.6e-05, atol=2e-04) 10524 10525 self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04) 10526 self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad) 10527 self.assertEqual(x_cpu.grad, x_mps.grad) 10528 10529 for mem_format_input in [torch.contiguous_format, torch.channels_last]: 10530 for mem_format_weight in [torch.contiguous_format, torch.channels_last]: 10531 for permute_data in [True, False]: 10532 helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data) 10533 helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data) 10534 helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data) 10535 10536 def test_conv_transpose_2d_strided(self): 10537 def helper(m_cpu, memory_format): 10538 m_mps = copy.deepcopy(m_cpu).requires_grad_() 10539 m_mps.weight.data = m_cpu.weight.data.detach().clone().to("mps").requires_grad_() 10540 m_mps.bias.data = m_cpu.bias.data.detach().clone().to("mps").requires_grad_() 10541 10542 input_cpu = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_() 10543 input_mps = input_cpu.detach().clone().to("mps") 10544 10545 output_cpu = m_cpu(input_cpu) 10546 output_mps = m_mps(input_mps) 10547 self.assertEqual(output_cpu, output_mps) 10548 10549 for mem_format_input in [torch.contiguous_format, torch.channels_last]: 10550 # With square kernels and equal stride 10551 helper(nn.ConvTranspose2d(16, 33, 3, stride=2).requires_grad_(), mem_format_input) 10552 10553 # non-square kernels and unequal stride and with padding 10554 helper(nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)).requires_grad_(), mem_format_input) 10555 10556 def test_conv_transpose_2d_specified_output(self): 10557 input_cpu = torch.randn(1, 16, 12, 12) 10558 input_mps = input_cpu.detach().clone().to("mps") 10559 10560 downsample_cpu = nn.Conv2d(16, 16, 3, stride=2, padding=1) 10561 downsample_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps") 10562 downsample_mps.weight.data = downsample_cpu.weight.data.detach().clone().to("mps").requires_grad_() 10563 downsample_mps.bias.data = downsample_cpu.bias.data.detach().clone().to("mps").requires_grad_() 10564 10565 upsample_cpu = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1) 10566 upsample_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps") 10567 upsample_mps.weight.data = upsample_cpu.weight.data.detach().clone().to("mps").requires_grad_() 10568 upsample_mps.bias.data = upsample_cpu.bias.data.detach().clone().to("mps").requires_grad_() 10569 10570 h_cpu = downsample_cpu(input_cpu) 10571 h_mps = downsample_mps(input_mps) 10572 self.assertEqual(h_cpu, h_mps) 10573 10574 size_cpu = h_cpu.size() 10575 size_mps = h_mps.size() 10576 self.assertEqual(size_cpu, size_mps) 10577 10578 output_cpu = upsample_cpu(h_cpu, output_size=input_cpu.size()) 10579 output_mps = upsample_mps(h_mps, output_size=input_mps.size()) 10580 self.assertEqual(output_cpu, output_mps) 10581 self.assertEqual(output_cpu.size(), output_mps.size()) 10582 10583 def test_conv2d_single_stride(self): 10584 y_cpu = torch.randn(2, 2, 3, 
6) 10585 y_gpu = y_cpu.to(device='mps') 10586 for stride in range(1, 4): 10587 conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=stride) 10588 conv_gpu = copy.deepcopy(conv_cpu).to(device='mps') 10589 x_cpu = conv_cpu(y_cpu) 10590 x_gpu = conv_gpu(y_gpu) 10591 self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05) 10592 10593 @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12") 10594 def test_conv3d_single_stride(self): 10595 # Conv3d is only available from MacOS 13.2 onwards 10596 y_cpu = torch.randn(2, 2, 3, 6) 10597 y_gpu = y_cpu.to(device='mps') 10598 for stride in range(1, 4): 10599 conv_cpu = torch.nn.Conv3d(in_channels=2, out_channels=2, kernel_size=2, stride=stride) 10600 conv_gpu = copy.deepcopy(conv_cpu).to(device='mps') 10601 x_cpu = conv_cpu(y_cpu) 10602 x_gpu = conv_gpu(y_gpu) 10603 self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05) 10604 10605 def test_grid_sample(self): 10606 def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad): 10607 def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners): 10608 for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]: 10609 # grid_dim_contig_order specifies the dimension order that can 10610 # make grid to be contiguous. 10611 # i.e., grid.permute(grid_dim_contig_order) is contiguous. 10612 # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be 10613 # initialized with contiguous tensor of shape [N, 2, H, W] 10614 # and permuted to [N, H, W, 2] afterwards. 10615 grid_shape = [N, H, W, 2] 10616 grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order] 10617 grid_fwd_permute = [None, None, None, None] 10618 for i, d in enumerate(grid_dim_contig_order): 10619 grid_fwd_permute[d] = i 10620 10621 def get_grid(device='cpu', data=None): 10622 if data is not None: 10623 assert list(data.shape) == grid_shape 10624 data = data.permute(grid_dim_contig_order).to(device) 10625 else: 10626 data = torch.randn(grid_init_shape, device=device) 10627 grid = data.permute(grid_fwd_permute) 10628 assert grid.permute(grid_dim_contig_order).is_contiguous() 10629 return grid 10630 10631 input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad) 10632 grid_cpu = get_grid().requires_grad_() 10633 out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode, 10634 align_corners=align_corners) 10635 self.assertEqual(out_cpu.size(), torch.Size([N, C, H, W])) 10636 10637 gradients = torch.randn_like(out_cpu) 10638 out_cpu.backward(gradients) 10639 10640 10641 # Compare against unvectorized CPU fallback 10642 10643 # NOTE [ grid_sample CPU fallback ] 10644 # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for 10645 # 32-bit floats. So we also have a fallback that is used only for float tensors 10646 # requiring 64-bit indexing. That requires too much memory to run on CI, so we 10647 # also export the fallback and test it here to ensure feature parity with 10648 # the vectorized version. 
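                # The fallback runs in float32 (hence the .float() casts below);
                # its output and gradients are compared against the vectorized
                # CPU results before the same inputs are replayed on MPS.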
10649 input_fallback = input_cpu.float().detach_().requires_grad_() 10650 grid_fallback = grid_cpu.float().detach_().requires_grad_() 10651 out_fallback = torch._grid_sampler_2d_cpu_fallback( 10652 input_fallback, grid_fallback, 10653 F.GRID_SAMPLE_INTERPOLATION_MODES[mode], 10654 F.GRID_SAMPLE_PADDING_MODES[padding_mode], 10655 align_corners) 10656 self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5) 10657 10658 out_fallback.backward(gradients.float()) 10659 if input_requires_grad: 10660 self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5) 10661 self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5) 10662 10663 input_mps = input_cpu.detach().transpose(0, 1).to("mps").transpose(0, 1).requires_grad_(input_requires_grad) 10664 grid_mps = get_grid('mps', grid_cpu.detach()).requires_grad_() 10665 out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners) 10666 self.assertEqual(out_cpu, out_mps) 10667 out_mps.backward(gradients.to("mps")) 10668 if input_requires_grad: 10669 self.assertEqual(input_cpu.grad, input_mps.grad) 10670 self.assertEqual(grid_cpu.grad, grid_mps.grad, atol=5e-5, rtol=0) 10671 10672 # check that zero-dimensional input strides don't error out 10673 base_input = torch.randn(N, C, 1, IW) 10674 input_cpu = base_input.expand_as(input_mps).requires_grad_(input_requires_grad) 10675 out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode, 10676 align_corners=align_corners) 10677 10678 input_mps = base_input.to("mps").expand_as(input_mps).requires_grad_(input_requires_grad) 10679 out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners) 10680 self.assertEqual(out_cpu, out_mps) 10681 10682 # test same size output 10683 test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners) 10684 10685 # test larger output 10686 N = random.randint(2, 8) 10687 C = random.randint(2, 8) 10688 IH = random.randint(2, 8) 10689 IW = random.randint(2, 8) 10690 H = random.randint(IH + 1, 12) 10691 W = random.randint(IW + 1, 12) 10692 test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners) 10693 10694 # test smaller output 10695 N = random.randint(2, 8) 10696 C = random.randint(2, 8) 10697 IH = random.randint(2, 8) 10698 IW = random.randint(2, 8) 10699 H = random.randint(2, IH) 10700 W = random.randint(2, IW) 10701 test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners) 10702 10703 # test 1x1 input 10704 N = random.randint(2, 8) 10705 C = random.randint(2, 8) 10706 IH = 1 10707 IW = 1 10708 H = random.randint(2, 5) 10709 W = random.randint(2, 5) 10710 test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners) 10711 10712 # testing empty grid 10713 N = random.randint(2, 8) 10714 C = random.randint(2, 8) 10715 IH = random.randint(2, 8) 10716 IW = random.randint(2, 8) 10717 W = random.randint(3, IW + 2) 10718 test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners) 10719 10720 # testing empty channel 10721 N = random.randint(2, 8) 10722 IH = random.randint(2, 8) 10723 IW = random.randint(2, 8) 10724 H = random.randint(3, IH + 2) 10725 W = random.randint(3, IW + 2) 10726 test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners) 10727 10728 # testing empty batch 10729 C = random.randint(2, 8) 10730 IH = random.randint(2, 8) 10731 IW = random.randint(2, 8) 10732 H = random.randint(3, IH + 2) 10733 W = random.randint(3, IW + 2) 10734 test_shape(0, C, IH, IW,
H, W, mode, padding_mode, align_corners) 10735 10736 for mode in ('bilinear', 'nearest'): 10737 for padding_mode in ('zeros', 'reflection'): 10738 for align_corners in (True, False): 10739 # test known input 10740 input = torch.arange(1., 11, device="mps").view(1, 1, 2, 5) 10741 grid = torch.tensor( 10742 [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]], 10743 [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]], device="mps").view(1, 2, 5, 2) 10744 if mode == 'bilinear': 10745 if padding_mode == 'zeros': 10746 if align_corners: 10747 groundtruth = torch.tensor( 10748 [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000], 10749 [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]], device="mps").view(1, 1, 2, 5) 10750 else: 10751 groundtruth = torch.tensor( 10752 [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250], 10753 [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]], device="mps").view(1, 1, 2, 5) 10754 elif padding_mode == 'border': 10755 if align_corners: 10756 groundtruth = torch.tensor( 10757 [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000], 10758 [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]], device="mps").view(1, 1, 2, 5) 10759 else: 10760 groundtruth = torch.tensor( 10761 [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500], 10762 [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]], device="mps").view(1, 1, 2, 5) 10763 elif padding_mode == 'reflection': 10764 if align_corners: 10765 groundtruth = torch.tensor( 10766 [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000], 10767 [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]], device="mps").view(1, 1, 2, 5) 10768 else: 10769 groundtruth = torch.tensor( 10770 [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500], 10771 [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]], device="mps").view(1, 1, 2, 5) 10772 else: 10773 raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'") 10774 elif mode == 'nearest': 10775 if padding_mode == 'zeros': 10776 if align_corners: 10777 groundtruth = torch.tensor( 10778 [[0., 8., 5., 7., 9.], 10779 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5) 10780 else: 10781 groundtruth = torch.tensor( 10782 [[0., 8., 5., 7., 0.], 10783 [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5) 10784 elif padding_mode == 'border': 10785 if align_corners: 10786 groundtruth = torch.tensor( 10787 [[1., 8., 5., 7., 9.], 10788 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5) 10789 else: 10790 groundtruth = torch.tensor( 10791 [[1., 8., 5., 7., 9.], 10792 [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5) 10793 elif padding_mode == 'reflection': 10794 if align_corners: 10795 groundtruth = torch.tensor( 10796 [[1., 8., 5., 7., 9.], 10797 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5) 10798 else: 10799 groundtruth = torch.tensor( 10800 [[1., 8., 5., 7., 9.], 10801 [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5) 10802 else: 10803 raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'") 10804 elif mode == 'bicubic': 10805 if padding_mode == 'zeros': 10806 if align_corners: 10807 groundtruth = torch.tensor( 10808 [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000], 10809 [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]], device="mps").view(1, 1, 2, 5) 10810 else: 10811 groundtruth = torch.tensor( 10812 [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264], 10813 [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]], device="mps").view(1, 1, 2, 5) 10814 elif padding_mode == 'border': 10815 if align_corners: 10816 
                                groundtruth = torch.tensor(
                                    [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
                                     [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
                                     [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
                                     [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
                                     [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]], device="mps").view(1, 1, 2, 5)
                        else:
                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")

                    else:
                        raise AssertionError(f"missing groundtruth test for interpolation mode '{mode}'")
                    output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
                                           align_corners=align_corners)
                    self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
                                     msg=f"groundtruth comparison failed for mode={mode}, "
                                         f"padding_mode={padding_mode}")


class TestAdvancedIndexing(TestCaseMPS):
    supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
    supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]

    @unittest.skipIf(product_version < 14.0, "Skipped on macOS < 14")
    def test_nonzero_no_warning(self):
        device = "mps"
        t = torch.randn((2, 2), device=device)
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            torch.nonzero(t)
            t.nonzero()
            self.assertEqual(len(w), 0)

    def test_nonzero(self):
        def helper(dtype):
            device = "mps"
            shapes = [
                torch.Size((12,)),
                torch.Size((12, 1)),
                torch.Size((1, 12)),
                torch.Size((6, 2)),
                torch.Size((3, 2, 2)),
                torch.Size((5, 5, 5)),
            ]

            def gen_nontrivial_input(shape, dtype, device):
                if dtype != torch.bfloat16:
                    return torch.randint(2, shape, device=device, dtype=dtype)
                else:
                    # randint does not support bfloat16; generate as float and cast
                    return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)

            for shape in shapes:
                tensor = gen_nontrivial_input(shape, dtype, device)
                dst1 = torch.nonzero(tensor, as_tuple=False)
                dst2 = tensor.nonzero(as_tuple=False)
                dst3 = torch.empty([], dtype=torch.long, device=device)
                dst3 = dst3.resize_(0)
                torch.nonzero(tensor, out=dst3)
                np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy()
                np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()
                self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0)
                self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0)
                self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0)
                tup1 = torch.nonzero(tensor, as_tuple=True)
                tup2 = tensor.nonzero(as_tuple=True)
                tup1 = torch.stack(tup1).t().cpu()
                tup2 = torch.stack(tup2).t().cpu()
                self.assertEqual(tup1, np_result, atol=0, rtol=0)
                self.assertEqual(tup2, np_result, atol=0, rtol=0)
        [helper(dtype) for dtype in self.supported_dtypes]

    def test_nonzero_astuple_out(self):
        device = "mps"
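        # nonzero has two calling conventions: as_tuple=False returns a single 2D
        # tensor of indices, as_tuple=True returns one 1D tensor per dimension;
        # only the former may be combined with an out= tensor, as checked below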
"mps" 10898 t = torch.randn((3, 3, 3), device=device) 10899 out = torch.empty([], dtype=torch.long, device=device) 10900 out = out.resize_(0) 10901 10902 with self.assertRaises(RuntimeError): 10903 torch.nonzero(t, as_tuple=True, out=out) 10904 10905 self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out)) 10906 10907 # Verifies that JIT script cannot handle the as_tuple kwarg 10908 # See Issue https://github.com/pytorch/pytorch/issues/45499. 10909 def _foo(t): 10910 tuple_result = torch.nonzero(t, as_tuple=True) 10911 nontuple_result = torch.nonzero(t, as_tuple=False) 10912 out = torch.empty_like(nontuple_result) 10913 torch.nonzero(t, as_tuple=False, out=out) 10914 return tuple_result, nontuple_result, out 10915 10916 with self.assertRaises(RuntimeError): 10917 scripted_foo = torch.jit.script(_foo) 10918 10919 # Verifies that JIT tracing works fine 10920 traced_foo = torch.jit.trace(_foo, t) 10921 traced_tuple, traced_nontuple, traced_out = traced_foo(t) 10922 expected_tuple = torch.nonzero(t, as_tuple=True) 10923 expected_nontuple = torch.nonzero(t) 10924 10925 self.assertEqual(traced_tuple, expected_tuple) 10926 self.assertEqual(traced_nontuple, expected_nontuple) 10927 self.assertEqual(traced_out, expected_nontuple) 10928 10929 def test_nonzero_discontiguous(self): 10930 device = "mps" 10931 shape = (4, 4) 10932 tensor = torch.randint(2, shape, device=device) 10933 tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor) 10934 dst1 = tensor.nonzero(as_tuple=False) 10935 dst2 = tensor_nc.nonzero(as_tuple=False) 10936 self.assertEqual(dst1, dst2, atol=0, rtol=0) 10937 dst3 = torch.empty_like(dst1) 10938 data_ptr = dst3.data_ptr() 10939 # expect dst3 storage to be reused 10940 torch.nonzero(tensor, out=dst3) 10941 self.assertEqual(data_ptr, dst3.data_ptr()) 10942 self.assertEqual(dst1, dst3, atol=0, rtol=0) 10943 # discontiguous out 10944 dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2] 10945 data_ptr = dst4.data_ptr() 10946 strides = dst4.stride() 10947 torch.nonzero(tensor, out=dst4) 10948 self.assertEqual(data_ptr, dst4.data_ptr()) 10949 self.assertEqual(dst1, dst4, atol=0, rtol=0) 10950 self.assertEqual(strides, dst4.stride()) 10951 10952 def test_nonzero_non_diff(self): 10953 device = "mps" 10954 x = torch.randn(10, requires_grad=True, device=device) 10955 nz = x.nonzero() 10956 self.assertFalse(nz.requires_grad) 10957 10958 def test_nonzero_multi_threading(self): 10959 # Test that MPS doesn't crash if nonzero called concurrently 10960 # See https://github.com/pytorch/pytorch/issues/100285 10961 x = torch.rand(3, 3, device="mps") 10962 t1 = threading.Thread(target=torch.nonzero, args=(x,)) 10963 t2 = threading.Thread(target=torch.nonzero, args=(x,)) 10964 t1.start() 10965 t2.start() 10966 10967 def test_sliced_view_cast(self): 10968 # This used to crash on MacOS Sequoia 10969 # See https://github.com/pytorch/pytorch/issues/137800 10970 x = torch.rand(16, 16, device='mps', dtype=torch.float16) 10971 y = x[:, 0:2].view(torch.float32) + 1 10972 10973 def test_masked_select(self): 10974 x = torch.randn(3, 4) 10975 x_mps = x.to("mps") 10976 mask = x.ge(0.5) 10977 mask_mps = x_mps.ge(0.5) 10978 10979 res = torch.masked_select(x, mask) 10980 res_mps = torch.masked_select(x_mps, mask_mps) 10981 10982 self.assertEqual(res, res_mps) 10983 10984 # examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm 10985 def test_indexing_get(self): 10986 def helper(dtype): 10987 
            x_cpu = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype)
            x_mps = x_cpu.detach().clone().to("mps")

            y_cpu = x_cpu[[0, 1, 2], [0, 1, 0]]
            y_mps = x_mps[[0, 1, 2], [0, 1, 0]]
            self.assertEqual(y_cpu, y_mps, str(dtype))
        [helper(dtype) for dtype in self.supported_dtypes]

    def test_indexing_select_corners(self):
        def helper(dtype):
            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
            x_mps = x_cpu.detach().clone().to("mps")

            rows_cpu = torch.tensor([[0, 0], [3, 3]])
            rows_mps = rows_cpu.detach().clone().to("mps")

            cols_cpu = torch.tensor([[0, 2], [0, 2]])
            cols_mps = cols_cpu.detach().clone().to("mps")

            res_cpu = x_cpu[rows_cpu, cols_cpu]
            res_mps = x_mps[rows_mps, cols_mps]

            self.assertEqual(res_cpu, res_mps, str(dtype))
        [helper(dtype) for dtype in self.supported_dtypes]

    # FIXME: uint8 fails for this testcase, needs further debugging
    def test_slicing_using_advanced_index_for_column(self):
        def helper(dtype):
            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
            x_mps = x_cpu.detach().clone().to("mps")

            z_cpu = x_cpu[1:4, 1:3]
            z_mps = x_mps[1:4, 1:3]
            self.assertEqual(z_cpu, z_mps, str(dtype))

            # using advanced index for column
            y_cpu = x_cpu[1:4, [1, 2]]
            y_mps = x_mps[1:4, [1, 2]]
            self.assertEqual(y_cpu, y_mps, str(dtype))
        # FIXME: use supported_dtypes once uint8 is fixed
        [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16]]

    def test_boolean_array_indexing(self):
        def helper(dtype):
            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
            x_mps = x_cpu.detach().clone().to("mps")

            res_cpu = x_cpu[x_cpu > 5]
            res_mps = x_mps[x_mps > 5]

            self.assertEqual(res_cpu, res_mps, str(dtype))
        for dtype in self.supported_dtypes:
            # MPS supports binary ops with uint8 natively starting from macOS 13.0
            if product_version < 13.0 and dtype == torch.uint8:
                continue
            helper(dtype)

    def test_advanced_indexing_3D_get(self):
        def helper(x_cpu):
            x_mps = x_cpu.detach().clone().to("mps")
            self.assertEqual(x_cpu[[1, 2], 3, :], x_mps[[1, 2], 3, :])
            self.assertEqual(x_cpu[[0, 2], :, :], x_mps[[0, 2], :, :])
            self.assertEqual(x_cpu[:, [1, 0], [1]], x_mps[:, [1, 0], [1]])

        x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                               [0.5, 0.6, 0.7, 0.8],
                               [0.9, 1.0, 1.1, 1.2],
                               [1.3, 1.4, 1.5, 1.6]],

                              [[2.0, 2.1, 2.2, 2.3],
                               [2.4, 2.5, 2.6, 2.7],
                               [2.8, 2.9, 3.0, 3.1],
                               [3.2, 3.3, 3.4, 3.5]],

                              [[4.0, 4.1, 4.2, 4.3],
                               [4.4, 4.5, 4.6, 4.7],
                               [4.8, 4.9, 5.0, 5.1],
                               [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
        helper(x_cpu)
        for idx in range(len(self.supported_np_dtypes)):
            # torch.randn / torch.rand don't work with all dtypes
            # Generate input data for all dtypes on NumPy, then move to torch
            input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
            inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])

            helper(inputCPU)

    def test_advanced_indexing_3D_put(self):
        def helper(x_cpu):
            dtype = x_cpu.dtype
            x_mps = x_cpu.detach().clone().to("mps")

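            # the one-element views created below broadcast into every
            # advanced-indexed destination slice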
            out_tensor_cpu = torch.tensor([88, 99], dtype=dtype, device="cpu")
            out_tensor_cpu_view = out_tensor_cpu[1:]

            out_tensor_mps = torch.tensor([88, 99], dtype=dtype, device="mps")
            out_tensor_mps_view = out_tensor_mps[1:]

            x_cpu[[1, 2], 3, :] = out_tensor_cpu_view
            x_mps[[1, 2], 3, :] = out_tensor_mps_view
            self.assertEqual(x_cpu, x_mps)

            x_cpu[[0, 2], :, :] = out_tensor_cpu_view
            x_mps[[0, 2], :, :] = out_tensor_mps_view
            self.assertEqual(x_cpu, x_mps)

            x_cpu[:, [1, 0], [1]] = out_tensor_cpu_view
            x_mps[:, [1, 0], [1]] = out_tensor_mps_view
            self.assertEqual(x_cpu, x_mps)

        x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
                               [0.5, 0.6, 0.7, 0.8],
                               [0.9, 1.0, 1.1, 1.2],
                               [1.3, 1.4, 1.5, 1.6]],

                              [[2.0, 2.1, 2.2, 2.3],
                               [2.4, 2.5, 2.6, 2.7],
                               [2.8, 2.9, 3.0, 3.1],
                               [3.2, 3.3, 3.4, 3.5]],

                              [[4.0, 4.1, 4.2, 4.3],
                               [4.4, 4.5, 4.6, 4.7],
                               [4.8, 4.9, 5.0, 5.1],
                               [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
        helper(x_cpu)
        for idx in range(len(self.supported_np_dtypes)):
            # torch.randn / torch.rand don't work with all dtypes
            # Generate input data for all dtypes on NumPy, then move to torch
            input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
            inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])

            helper(inputCPU)

    def test_index_put_with_view_indices(self):
        def helper(dtype):
            target_cpu = torch.zeros([5, 3], device="cpu", dtype=dtype)
            target_mps = torch.zeros([5, 3], device="mps", dtype=dtype)

            indices_cpu = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="cpu")
            indices_mps = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="mps")

            value_cpu = torch.ones(indices_cpu.shape[0], device="cpu", dtype=dtype)
            value_mps = torch.ones(indices_mps.shape[0], device="mps", dtype=dtype)

            target_cpu.index_put_(tuple(indices_cpu.t()), value_cpu, accumulate=True)
            target_mps.index_put_(tuple(indices_mps.t()), value_mps, accumulate=True)

            self.assertEqual(target_cpu, target_mps)

        [helper(dtype) for dtype in [torch.int32, torch.float]]

    # tests from 'test_indexing.py'
    def test_advancedindex_big(self, device="mps"):
        reference = torch.arange(0, 123344, dtype=torch.int, device=device)

        self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
                         torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int))

    def test_set_item_to_scalar_tensor(self, device="mps"):
        m = random.randint(1, 10)
        n = random.randint(1, 10)
        z = torch.randn([m, n], device=device)
        a = 1.0
        w = torch.tensor(a, requires_grad=True, device=device)
        z[:, 0] = w
        z.sum().backward()
        self.assertEqual(w.grad, m * a)

    def test_single_int(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[4].shape, (7, 3))

    def test_multiple_int(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[4].shape, (7, 3))
        self.assertEqual(v[4, :, 1].shape, (7,))

    def test_none(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[None].shape, (1, 5, 7, 3))
        self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
        self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
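        # indexing with None behaves like unsqueeze: each None inserts a new
        # dimension of size 1 at that position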
        self.assertEqual(v[..., None].shape, (5, 7, 3, 1))

    def test_step(self, device="mps"):
        v = torch.arange(10, device=device)
        self.assertEqual(v[::1], v)
        self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
        self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
        self.assertEqual(v[::11].tolist(), [0])
        self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])

    def test_step_assignment(self, device="mps"):
        v = torch.zeros(4, 4, device=device)
        v[0, 1::2] = torch.tensor([3., 4.], device=device)
        self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
        self.assertEqual(v[1:].sum(), 0)

    def test_bool_indices(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
        self.assertEqual(v[boolIndices].shape, (3, 7, 3))
        self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))

        v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
        boolIndices = torch.tensor([True, False, False], dtype=torch.bool, device=device)
        uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
        with warnings.catch_warnings(record=True) as w:
            self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
            self.assertEqual(v[boolIndices], v[uint8Indices])
            self.assertEqual(v[boolIndices], torch.tensor([True], dtype=torch.bool, device=device))
            self.assertEqual(len(w), 2)

    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
    def test_bool_indices_accumulate(self, device="mps"):
        mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
        mask = mask > 0
        y = torch.ones(size=(10, 10), device=device)
        y.index_put_((mask, ), y[mask], accumulate=True)
        self.assertEqual(y, torch.ones(size=(10, 10), device=device))

    def test_multiple_bool_indices(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        # note: these broadcast together and are transposed to the first dim
        mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
        mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
        self.assertEqual(v[mask1, :, mask2].shape, (3, 7))

    def test_byte_mask(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
        with warnings.catch_warnings(record=True) as w:
            self.assertEqual(v[mask].shape, (3, 7, 3))
            self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
            self.assertEqual(len(w), 2)

        v = torch.tensor([1.], device=device)
        self.assertEqual(v[v == 0], torch.tensor([], device=device))

    def test_byte_mask_accumulate(self, device="mps"):
        mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
        y = torch.ones(size=(10, 10), device=device)
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            y.index_put_((mask, ), y[mask], accumulate=True)
            self.assertEqual(y, torch.ones(size=(10, 10), device=device))
            self.assertEqual(len(w), 2)

    def test_index_put_accumulate_expanded_values(self, device="mps"):
        t = torch.zeros((5, 2))
        t_dev = t.to(device)
        indices = [
            torch.tensor([0, 1, 2, 3]),
            torch.tensor([1, ]),
        ]
        indices_dev = [i.to(device) for i in indices]
        values0d = torch.tensor(1.0)
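        # both the 0-dim scalar and the 1-element vector below broadcast across
        # all four indexed positions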
        values1d = torch.tensor([1.0, ])

        out_mps = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values0d, accumulate=True)
        self.assertEqual(out_mps.cpu(), out_cpu)

        out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values1d, accumulate=True)
        self.assertEqual(out_mps.cpu(), out_cpu)

        t = torch.zeros(4, 3, 2)
        t_dev = t.to(device)

        indices = [
            torch.tensor([0, ]),
            torch.arange(3)[:, None],
            torch.arange(2)[None, :],
        ]
        indices_dev = [i.to(device) for i in indices]
        values1d = torch.tensor([-1.0, -2.0])
        values2d = torch.tensor([[-1.0, -2.0], ])

        out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values1d, accumulate=True)
        self.assertEqual(out_mps.cpu(), out_cpu)

        out_mps = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values2d, accumulate=True)
        self.assertEqual(out_mps.cpu(), out_cpu)

    def test_index_put_accumulate_non_contiguous(self, device="mps"):
        t = torch.zeros((5, 2, 2))
        t_dev = t.to(device)
        t1 = t_dev[:, 0, :]
        t2 = t[:, 0, :]
        self.assertFalse(t1.is_contiguous())
        self.assertFalse(t2.is_contiguous())

        indices = [torch.tensor([0, 1]), ]
        indices_dev = [i.to(device) for i in indices]
        value = torch.randn(2, 2)
        out_mps = t1.index_put_(indices_dev, value.to(device), accumulate=True)
        out_cpu = t2.index_put_(indices, value, accumulate=True)
        self.assertFalse(t1.is_contiguous())
        self.assertFalse(t2.is_contiguous())

        self.assertEqual(out_mps.cpu(), out_cpu)

    def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
        # TODO: replace with a better solution.
        # Currently, TorchScript is used here to put None into the indices list.
        # In C++, this yields the indices as a list of two optional tensors: the first
        # is null and the second is a valid tensor.
        @torch.jit.script
        def func(x, i, v):
            idx = [None, i]
            x.index_put_(idx, v, accumulate=True)
            return x

        n = 4
        t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
        t_dev = t.to(device)
        indices = torch.tensor([1, 0])
        indices_dev = indices.to(device)
        value0d = torch.tensor(10.0)
        value1d = torch.tensor([1.0, 2.0])

        out_mps = func(t_dev, indices_dev, value0d.to("mps"))
        out_cpu = func(t, indices, value0d)
        self.assertEqual(out_mps.cpu(), out_cpu)

        out_mps = func(t_dev, indices_dev, value1d.to("mps"))
        out_cpu = func(t, indices, value1d)
        self.assertEqual(out_mps.cpu(), out_cpu)

    def test_index_put_accumulate_duplicate_indices(self, device="mps"):
        for i in range(1, 128):
            # generate indices by random walk; this creates indices with
            # lots of duplicates interleaved with each other
            delta = torch.empty(i, dtype=torch.float32, device=device).uniform_(-1, 1)

            indices = delta.cumsum(0).long().to("mps")

            # abs for int64 is not supported on mps, fall back on 'cpu' to calculate it
            input = torch.randn(indices.cpu().abs().max().to("mps") + 1, device=device)
            values = torch.randn(indices.size(0), device=device)
            output = input.index_put((indices,), values, accumulate=True)

            input_list = input.tolist()
            indices_list = indices.tolist()
            values_list = values.tolist()
            for i, v in zip(indices_list, values_list):
                input_list[i] += v

            self.assertEqual(output, input_list)

    def test_index_put_deterministic(self, device="mps"):
        def helper(dtype, accumulate, deterministic, num_tests=128):
            acc_expected = torch.tensor([233, 187, 360], device=device, dtype=dtype)
            non_acc_expected = torch.tensor([38, 37, 39], device=device, dtype=dtype)
            t_idx = torch.tensor(
                [0, 0, 0, 0, 2, 2, 1, 0, 2, 1, 0, 1, 2, 1, 0, 2, 2, 2, 2, 2,
                 0, 0, 2, 1, 2, 1, 0, 0, 2, 0, 2, 1, 1, 2, 2, 0, 2, 1, 0, 2]
            )
            for _ in range(num_tests):
                try:
                    torch.use_deterministic_algorithms(deterministic)
                    t = torch.zeros(3, dtype=dtype, device=device)
                    t.index_put_((t_idx,), torch.arange(len(t_idx), device=device, dtype=dtype), accumulate=accumulate)
                    if accumulate:
                        self.assertEqual(t, acc_expected)
                    else:
                        self.assertEqual(t, non_acc_expected)
                finally:
                    torch.use_deterministic_algorithms(False)

        for accumulate, deterministic in product((False, True), (False, True)):
            dtype = torch.float if accumulate else torch.long
            if not accumulate and not deterministic:
                with self.assertRaisesRegex(AssertionError, "Tensor-likes are not equal!"):
                    helper(dtype, accumulate, deterministic)
            else:
                helper(dtype, accumulate, deterministic)

    def test_multiple_byte_mask(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        # note: these broadcast together and are transposed to the first dim
        mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
        mask2 = torch.ByteTensor([1, 1, 1]).to(device)
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
            self.assertEqual(len(w), 2)

    def test_byte_mask2d(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        c = torch.randn(5, 7, device=device)
        num_ones = (c > 0).sum()
        r = v[c > 0]
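        # a 2D boolean mask consumes the first two dimensions of v, so the
        # result has one row of length 3 per selected element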
        self.assertEqual(r.shape, (num_ones, 3))

    def test_jit_indexing(self, device="mps"):
        def fn1(x):
            x[x < 50] = 1.0
            return x

        def fn2(x):
            x[0:50] = 1.0
            return x

        scripted_fn1 = torch.jit.script(fn1)
        scripted_fn2 = torch.jit.script(fn2)
        data = torch.arange(100, device=device, dtype=torch.float)
        out = scripted_fn1(data.detach().clone())
        ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
        self.assertEqual(out, ref)
        out = scripted_fn2(data.detach().clone())
        self.assertEqual(out, ref)

    def test_int_indices(self, device="mps"):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
        self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
        self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))

    def test_index_put_src_datatype(self):
        def helper(device, dtype):
            src = torch.ones(3, 2, 4, device=device, dtype=dtype)
            vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
            indices = (torch.tensor([0, 2, 1]),)
            res = src.index_put_(indices, vals, accumulate=True)
            self.assertEqual(res.shape, src.shape)
        [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.int32]]

    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
    def test_index_src_datatype(self):
        def helper(device, dtype):
            orig_dtype = dtype
            if dtype is torch.bool:
                dtype = torch.uint8

            src = torch.ones(3, 2, 4, device=device, dtype=dtype)
            if orig_dtype is torch.bool:
                src = src == 1
            # test index
            res = src[[0, 2, 1], :, :]
            self.assertEqual(res.shape, src.shape)
            # test index_put, no accum
            src[[0, 2, 1], :, :] = res
            self.assertEqual(res.shape, src.shape)
        [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.float16, torch.long, torch.bool]]

    def test_int_indices2d(self, device="mps"):
        # From the NumPy indexing example
        x = torch.arange(0, 12, device=device).view(4, 3)
        rows = torch.tensor([[0, 0], [3, 3]], device=device)
        columns = torch.tensor([[0, 2], [0, 2]], device=device)
        self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])

    def test_int_indices_broadcast(self, device="mps"):
        # From the NumPy indexing example
        x = torch.arange(0, 12, device=device).view(4, 3)
        rows = torch.tensor([0, 3], device=device)
        columns = torch.tensor([0, 2], device=device)
        result = x[rows[:, None], columns]
        self.assertEqual(result.tolist(), [[0, 2], [9, 11]])

    def test_empty_index(self, device="mps"):
        x = torch.arange(0, 12, device=device).view(4, 3)
        idx = torch.tensor([], dtype=torch.long, device=device)
        self.assertEqual(x[idx].numel(), 0)

        # empty assignment should have no effect but not throw an exception
        y = x.clone()
        y[idx] = -1
        self.assertEqual(x, y)

        mask = torch.zeros(4, 3, device=device).bool()
        y[mask] = -1
        self.assertEqual(x, y)

    def test_empty_ndim_index(self, device="mps"):
        x = torch.randn(5, device=device)
        self.assertEqual(torch.empty(0, 2, device=device), x[torch.empty(0, 2, dtype=torch.int64, device=device)])

        x = torch.randn(2, 3, 4, 5, device=device)
        self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device),
                         x[:, torch.empty(0, 6, dtype=torch.int64, device=device)])

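        # empty advanced indices are fine, but a non-empty advanced index into a
        # dimension of size 0 must raise IndexError, as checked below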
        x = torch.empty(10, 0, device=device)
        self.assertEqual(x[[1, 2]].shape, (2, 0))
        self.assertEqual(x[[], []].shape, (0,))
        with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
            x[:, [0, 1]]

    def test_empty_ndim_index_bool(self, device="mps"):
        x = torch.randn(5, device=device)
        self.assertRaises(IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)])

    def test_empty_slice(self, device="mps"):
        x = torch.randn(2, 3, 4, 5, device=device)
        y = x[:, :, :, 1]
        z = y[:, 1:1, :]
        self.assertEqual((2, 0, 4), z.shape)
        # this isn't technically necessary, but matches NumPy stride calculations.
        self.assertEqual((60, 20, 5), z.stride())
        self.assertTrue(z.is_contiguous())

    def test_index_getitem_copy_bools_slices(self, device="mps"):
        true = torch.tensor(1, dtype=torch.uint8, device=device)
        false = torch.tensor(0, dtype=torch.uint8, device=device)

        tensors = [torch.randn(2, 3, device=device), torch.tensor(3., device=device)]

        for a in tensors:
            self.assertNotEqual(a.data_ptr(), a[True].data_ptr())
            self.assertEqual(torch.empty(0, *a.shape), a[False])
            self.assertNotEqual(a.data_ptr(), a[true].data_ptr())
            self.assertEqual(torch.empty(0, *a.shape), a[false])
            self.assertEqual(a.data_ptr(), a[None].data_ptr())
            self.assertEqual(a.data_ptr(), a[...].data_ptr())

    def test_index_setitem_bools_slices(self, device="mps"):
        true = torch.tensor(1, dtype=torch.uint8, device=device)
        false = torch.tensor(0, dtype=torch.uint8, device=device)

        tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]

        for a in tensors:
            # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
            # (some of these ops already prefix a 1 to the size)
            neg_ones = torch.ones_like(a) * -1
            neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
            a[True] = neg_ones_expanded
            self.assertEqual(a, neg_ones)
            a[False] = 5
            self.assertEqual(a, neg_ones)
            a[true] = neg_ones_expanded * 2
            self.assertEqual(a, neg_ones * 2)
            a[false] = 5
            self.assertEqual(a, neg_ones * 2)
            a[None] = neg_ones_expanded * 3
            self.assertEqual(a, neg_ones * 3)
            a[...] = neg_ones_expanded * 4
            self.assertEqual(a, neg_ones * 4)
            if a.dim() == 0:
                with self.assertRaises(IndexError):
                    a[:] = neg_ones_expanded * 5

    def test_index_scalar_with_bool_mask(self, device="mps"):
        a = torch.tensor(1, device=device)
        uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
        boolMask = torch.tensor(True, dtype=torch.bool, device=device)
        self.assertEqual(a[uintMask], a[boolMask])
        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)

        a = torch.tensor(True, dtype=torch.bool, device=device)
        self.assertEqual(a[uintMask], a[boolMask])
        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)

    def test_setitem_expansion_error(self, device="mps"):
        true = torch.tensor(True, device=device)
        a = torch.randn(2, 3, device=device)
        # check that a prefix with non-1s doesn't work
        a_expanded = a.expand(torch.Size([5, 1]) + a.size())
        # NumPy: ValueError
        with self.assertRaises(RuntimeError):
            a[True] = a_expanded
        with self.assertRaises(RuntimeError):
            a[true] = a_expanded

    def test_getitem_scalars(self, device="mps"):
        zero = torch.tensor(0, dtype=torch.int64, device=device)
        one = torch.tensor(1, dtype=torch.int64, device=device)

        # non-scalar indexed with scalars
        a = torch.randn(2, 3, device=device)
        self.assertEqual(a[0], a[zero])
        self.assertEqual(a[0][1], a[zero][one])
        self.assertEqual(a[0, 1], a[zero, one])
        self.assertEqual(a[0, one], a[zero, 1])

        # indexing by a scalar should slice (not copy)
        self.assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr())
        self.assertEqual(a[1].data_ptr(), a[one.int()].data_ptr())
        self.assertEqual(a[1].data_ptr(), a[one.short()].data_ptr())

        # scalar indexed with scalar
        r = torch.randn((), device=device)
        with self.assertRaises(IndexError):
            r[:]
        with self.assertRaises(IndexError):
            r[zero]
        self.assertEqual(r, r[...])

    def test_setitem_scalars(self, device="mps"):
        zero = torch.tensor(0, dtype=torch.int64)

        # non-scalar indexed with scalars
        a = torch.randn(2, 3, device=device)
        a_set_with_number = a.clone()
        a_set_with_scalar = a.clone()
        b = torch.randn(3, device=device)

        a_set_with_number[0] = b
        a_set_with_scalar[zero] = b
        self.assertEqual(a_set_with_number, a_set_with_scalar)
        a[1, zero] = 7.7
        self.assertEqual(7.7, a[1, 0])

        # scalar indexed with scalars
        r = torch.randn((), device=device)
        with self.assertRaises(IndexError):
            r[:] = 8.8
        with self.assertRaises(IndexError):
            r[zero] = 8.8
        r[...] = 9.9
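        # Ellipsis is the only indexing form that allows assignment into a
        # 0-dim tensor; slices and scalar-tensor indices raise IndexError above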
        self.assertEqual(9.9, r)

    def test_basic_advanced_combined(self, device="mps"):
        # From the NumPy indexing example
        x = torch.arange(0, 12, device=device).view(4, 3)
        self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
        self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])

        # Check that it is a copy
        unmodified = x.clone()
        x[1:2, [1, 2]].zero_()
        self.assertEqual(x, unmodified)

        # But assignment should modify the original
        unmodified = x.clone()
        x[1:2, [1, 2]] = 0
        self.assertNotEqual(x, unmodified)

    def test_int_assignment(self, device="mps"):
        x = torch.arange(0, 4, device=device).view(2, 2)
        x[1] = 5
        self.assertEqual(x.tolist(), [[0, 1], [5, 5]])

        x = torch.arange(0, 4, device=device).view(2, 2)
        x[1] = torch.arange(5, 7, device=device)
        self.assertEqual(x.tolist(), [[0, 1], [5, 6]])

    def test_byte_tensor_assignment(self, device="mps"):
        x = torch.arange(0., 16, device=device).view(4, 4)
        b = torch.ByteTensor([True, False, True, False]).to(device)
        value = torch.tensor([3., 4., 5., 6.], device=device)

        with warnings.catch_warnings(record=True) as w:
            x[b] = value
            self.assertEqual(len(w), 1)

        self.assertEqual(x[0], value)
        self.assertEqual(x[1], torch.arange(4., 8, device=device))
        self.assertEqual(x[2], value)
        self.assertEqual(x[3], torch.arange(12., 16, device=device))

    def test_variable_slicing(self, device="mps"):
        x = torch.arange(0, 16, device=device).view(4, 4)
        indices = torch.IntTensor([0, 1]).to(device)
        i, j = indices
        self.assertEqual(x[i:j], x[0:1])

    def test_ellipsis_tensor(self, device="mps"):
        x = torch.arange(0, 9, device=device).view(3, 3)
        idx = torch.tensor([0, 2], device=device)
        self.assertEqual(x[..., idx].tolist(), [[0, 2],
                                                [3, 5],
                                                [6, 8]])
        self.assertEqual(x[idx, ...].tolist(), [[0, 1, 2],
                                                [6, 7, 8]])

    def test_invalid_index(self, device="mps"):
        x = torch.arange(0, 16, device=device).view(4, 4)
        self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"])

    def test_out_of_bound_index(self, device="mps"):
        x = torch.arange(0, 100, device=device).view(2, 5, 10)
        self.assertRaisesRegex(IndexError, 'index 5 is out of bounds for dimension 1 with size 5', lambda: x[0, 5])
        self.assertRaisesRegex(IndexError, 'index 4 is out of bounds for dimension 0 with size 2', lambda: x[4, 5])
        self.assertRaisesRegex(IndexError, 'index 15 is out of bounds for dimension 2 with size 10',
                               lambda: x[0, 1, 15])
        self.assertRaisesRegex(IndexError, 'index 12 is out of bounds for dimension 2 with size 10',
                               lambda: x[:, :, 12])

    def test_zero_dim_index(self, device="mps"):
        x = torch.tensor(10, device=device)
        self.assertEqual(x, x.item())

        def runner():
            print(x[0])
            return x[0]

        self.assertRaisesRegex(IndexError, 'invalid index', runner)

    def test_cpu_indices(self, device="mps"):
        idx = torch.tensor([0, 1])
        b = torch.zeros(2, device=device)
        x = torch.ones(10, device=device)
        x[idx] = b  # index_put_
        ref = torch.ones(10, device=device)
        ref[:2] = 0
        self.assertEqual(x, ref, atol=0, rtol=0)
        out = x[idx]  # index
        self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)

    def test_nextafter(self, device="mps"):
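        # nextafter(x, y) returns the next representable float after x in the
        # direction of y, so whether the result is greater than x must agree
        # between the MPS and CPU backends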
device="mps"): 11692 for dtype in [torch.float16, torch.float32]: 11693 x = torch.tensor([1, -1, 0, 0, 2, -2], device=device, dtype=dtype) 11694 y = torch.tensor([2, -2, -1, 1, -3, 3], device=device, dtype=dtype) 11695 na = torch.nextafter(x, y) 11696 na_cpu = torch.nextafter(x.cpu(), y.cpu()) 11697 na_ge_x_mps = na.cpu() > x.cpu() 11698 # greater is broken on MPS, see https://github.com/pytorch/pytorch/issues/125051 11699 na_ge_x_cpu = na_cpu > x.cpu() 11700 self.assertEqual(na_ge_x_mps, na_ge_x_cpu) 11701 11702 11703class TestRNNMPS(TestCaseMPS): 11704 def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False, 11705 seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False): 11706 rnn = nn.LSTM( 11707 input_size=input_size, 11708 hidden_size=hidden_size, 11709 num_layers=num_layers, 11710 bias=bias, 11711 bidirectional=bidirectional, 11712 batch_first=batch_first, 11713 device="cpu" 11714 ) 11715 bidirectional_mul = 2 if bidirectional else 1 11716 11717 if batch_first: 11718 input = torch.randn(batch_size, seq_len, input_size, device="cpu", dtype=dtype, requires_grad=backward) 11719 hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype, 11720 requires_grad=backward) 11721 cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype, 11722 requires_grad=backward) 11723 else: 11724 input = torch.randn(seq_len, batch_size, input_size, device="cpu", dtype=dtype, requires_grad=backward) 11725 hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype, 11726 requires_grad=backward) 11727 cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype, 11728 requires_grad=backward) 11729 11730 cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx)) 11731 11732 rnn = rnn.to(device) 11733 input = input.to(device) 11734 hx = hx.to(device) 11735 cx = cx.to(device) 11736 output, (hn, cn) = rnn(input, (hx, cx)) 11737 11738 self.assertEqual(cpu_output, output) 11739 self.assertEqual(cpu_hn, hn) 11740 self.assertEqual(cpu_cn, cn) 11741 11742 def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True): 11743 rnn = rnn.to(device) 11744 inp, hx, cx = inp.to(device), hx.to(device), cx.to(device) 11745 11746 output, (hx_out, cx_out) = rnn(inp, (hx, cx)) 11747 assert output_grad_presented or states_grad_presented, "At least some outputs must be used" 11748 11749 f = 0 11750 if output_grad_presented: 11751 f = f + 3 * output.sum() 11752 if states_grad_presented: 11753 f = f + (hx_out * cx_out).sum() 11754 11755 param_names, params = zip(*rnn.named_parameters()) 11756 param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True)) 11757 11758 input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx]) 11759 return output, param_grads, input_grad, hx_grad, cx_grad 11760 11761 if backward: 11762 grad_cases = [ 11763 dict(output_grad_presented=True, states_grad_presented=True), 11764 dict(output_grad_presented=False, states_grad_presented=True), 11765 dict(output_grad_presented=True, states_grad_presented=False), 11766 ] 11767 11768 for grad_case in grad_cases: 11769 cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\ 11770 get_backward_results(rnn, "cpu", input, hx, cx, **grad_case) 11771 mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\ 11772 get_backward_results(rnn, device, input, 
                self.assertEqual(cpu_hx_grad, mps_hx_grad)
                self.assertEqual(cpu_cx_grad, mps_cx_grad)
                self.assertEqual(cpu_output, mps_output)
                self.assertEqual(cpu_input_grad, mps_input_grad)
                for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
                    self.assertEqual(cpu_weight_grad, mps_weight_grad,
                                     f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")

    LSTM_TEST_CASES = [
        {},  # default
        dict(batch_first=True),
        dict(bias=False),
        dict(bidirectional=True),
        dict(batch_first=True, bias=False),
        dict(bidirectional=True, bias=False),
        dict(bidirectional=True, batch_first=True),
        dict(bidirectional=True, batch_first=True, bias=False)
    ]

    def test_lstm_forward(self, device="mps", dtype=torch.float32):
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options)

    def test_lstm_backward(self, device="mps", dtype=torch.float32):
        for num_layers in [1, 2, 5]:
            for test_options in self.LSTM_TEST_CASES:
                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options)

    def test_RNN_cell_no_broadcasting(self):
        def test(cell_module, input, hx, input_size, hidden_size):
            cell = cell_module(input_size, hidden_size, device='mps')
            self.assertRaises(RuntimeError, lambda: cell(input, hx))

        def test_all(hidden_size, bad_hx, good_hx, input_size, input):
            test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
            test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
            test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
            test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)

        hidden_size = 20
        input_size = 10
        input = torch.randn(3, input_size, device='mps')
        bad_hx = torch.randn(1, hidden_size, device='mps')
        good_hx = torch.randn(3, hidden_size, device='mps')

        # Test hidden/input batch size broadcasting
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test hx's hidden_size vs module's hidden_size broadcasting
        bad_hx = torch.randn(3, 1)
        test_all(hidden_size, bad_hx, good_hx, input_size, input)

        # Test input's input_size vs module's input_size broadcasting
        bad_input = torch.randn(3, 1)
        test_all(hidden_size, good_hx, good_hx, input_size, bad_input)

    def test_LSTM_cell(self):
        # this is just a smoke test; these modules are implemented through
        # autograd so no Jacobian test is needed
        for bias in (True, False):
            input = torch.randn(3, 10, device='mps')
            hx = torch.randn(3, 20, device='mps')
            cx = torch.randn(3, 20, device='mps')
            lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
            for _ in range(6):
                hx, cx = lstm(input, (hx, cx))

            (hx + cx).sum().backward()

    def test_LSTM_cell_forward_input_size(self):
        input = torch.randn(3, 11, device='mps')
        hx = torch.randn(3, 20, device='mps')
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))

    def test_LSTM_cell_forward_hidden_size(self):
        input = torch.randn(3, 10, device='mps')
        hx = torch.randn(3, 21, device='mps')
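        # hx deliberately has hidden size 21 while the cell expects 20, so both
        # forward calls below must raise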
        cx = torch.randn(3, 20, device='mps')
        lstm = nn.LSTMCell(10, 20, device='mps')
        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
        self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))


class TestFallbackWarning(TestCase):
    # TODO: Remove once test_testing.py is running on MPS devices
    def test_no_warning_on_import(self):
        out = subprocess.check_output(
            [sys.executable, "-W", "always", "-c", "import torch"],
            stderr=subprocess.STDOUT,
            # On Windows, opening the subprocess with the default CWD makes `import torch`
            # fail, so just set CWD to this script's directory
            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
        self.assertEqual(out, "")

    def _get_not_implemented_op(self):
        # This can be changed once we actually implement 'lcm'
        # Should return fn, args, kwargs, string_version
        return (torch.lcm,
                [torch.tensor([1], device='mps'), torch.tensor([2], device='mps')], {},
                "torch.lcm(torch.tensor([1], device='mps'), torch.tensor([2], device='mps'))")

    def test_error_on_not_implemented(self):
        fn, args, kwargs, _ = self._get_not_implemented_op()

        with self.assertRaisesRegex(NotImplementedError, "not currently implemented for the MPS device"):
            fn(*args, **kwargs)

    def test_warn_on_not_implemented_with_fallback(self):
        _, _, _, op = self._get_not_implemented_op()
        script = f"""
import os
# MUST happen before pytorch's import
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import warnings

with warnings.catch_warnings(record=True) as w:
    import torch

if len(w) > 0:
    print(w)
    exit(1)

# This should run just fine and raise a warning about perf
with warnings.catch_warnings(record=True) as w:
    {op}

if len(w) != 1:
    print(w)
    exit(2)
"""
        try:
            subprocess.check_output(
                [sys.executable, '-W', 'always', '-c', script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),)
        except subprocess.CalledProcessError as e:
            if e.returncode == 1:
                self.assertTrue(False, "There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set. " +
                                e.output.decode("utf-8"))
            elif e.returncode == 2:
                self.assertTrue(False, "There wasn't exactly one warning when running a not implemented op with "
                                f"PYTORCH_ENABLE_MPS_FALLBACK set. {e.output}")
            else:
                self.assertTrue(False, "Running a not implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set. " +
                                e.output.decode("utf-8"))


class TestNoRegression(TestCase):
    def test_assert_close(self):
        a = torch.ones(1, device="mps")
        b = torch.zeros(1, device="mps")
        inf = a / b
        nan = b / b

        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
            torch.testing.assert_close(a, inf)

        # TODO: The NaN test is failing when all the tests in test_mps are run
        # together but passes when run separately. There seems to be memory
        # corruption which needs to be fixed for this test to be enabled.
        # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
        #     torch.testing.assert_close(a, nan)

    def test_double_error(self):
        with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
            a = torch.ones(2, dtype=torch.float64, device="mps")

        a = torch.ones(2, device="mps")
        with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
            a = a.double()

    def test_legacy_constructor(self):
        a = torch.ones(2, device="mps")

        b = a.new(1)

    def test_serialization_map_location(self):

        # Ensures that cpu Tensors can be loaded on mps
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2)
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f, map_location="mps")

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "mps")

        # Ensures that mps Tensors can be loaded on mps
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2, device="mps")
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f)

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "mps")

        # Ensures that mps Tensors can be loaded on cpu
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2, device="mps")
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f, map_location="cpu")

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "cpu")

        # Ensures that `mps:0` Tensors can be loaded on mps
        with tempfile.NamedTemporaryFile() as f:
            x = torch.rand(2, device="mps:0")
            torch.save(x, f)

            f.seek(0)
            x2 = torch.load(f, map_location="mps:0")

            self.assertEqual(x, x2)
            self.assertEqual(x2.device.type, "mps")


MPS_DTYPES = get_all_dtypes()
for t in [torch.double, torch.cdouble, torch.cfloat, torch.bfloat16]:
    del MPS_DTYPES[MPS_DTYPES.index(t)]

MPS_GRAD_DTYPES = [torch.float32, torch.float16]


class TestConsistency(TestCaseMPS):
    # TODO: This is only used while some ops are being added.
    # This list should contain all ops and dtypes eventually
    # This can be generated automatically in the `new_mps_allowlist.txt` file
    # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
    # You most likely do NOT want to modify this manually

    FP16_LOW_PRECISION_LIST = {
        'add', 'sub', 'div', 'addcdiv',
        '__rdiv__', '__rmul__',
        'nn.functional.huber_loss',
        'true_divide', 'kron',
        'gradient', 'var', 'std', 'std_mean', 'ldexp',
        'linalg.vector_norm', 'lerp',
        'addr', 'var_mean',
        'var_mean_unbiased',
        'acosh', 'asinh', 'asin',
        'masked.std',
        'nn.functional.normalize',
        'nn.functional.triplet_margin_loss',
        'nn.functional.triplet_margin_with_distance_loss',
        'nn.functional.batch_norm',
        'nn.functional.instance_norm',
        'round', 'xlogy', 'addcmul',
        'nn.functional.cross_entropy',
        'nn.functional.binary_cross_entropy',
        'nn.functional.nll_loss',
        'nn.functional.max_pool2d',
        'nn.functional.gelu',
        'nn.functional.glu',
        '_native_batch_norm_legit',
        '_batch_norm_with_update',
        'native_batch_norm',
        'softmax',
        '_softmax_backward_data',
        'log_softmax',
        'masked.softmax',
        'masked.log_softmax',
        'masked.softmin',
        'nn.functional.kl_div',
        'nn.functional.softmin',
        'cross', 'linalg.cross',
        'prod', 'masked.prod',
        'nextafter',
        'native_layer_norm',
        'nn.functional.layer_norm',
        'nn.functional.interpolate',
        'nn.functional.upsample_bilinear',
        'nn.functional.upsample_nearest',

        # for macOS 12
        'masked.normalize', 'masked.sum', 'masked.var',
        'outer',
        'sum_to_size', 'sum',
        'mul',
        'nansum', 'nanmean',
        'norm',
    }

    FP32_LOW_PRECISION_LIST = {
        # conv2d and conv_transpose2d results have a very small
        # difference compared to CPU/CUDA, so we use lower precision on FP32
        'nn.functional.conv2d',
        'nn.functional.conv_transpose2d',
        'matmul', '__rmatmul__',
        'linalg.multi_dot',
        'addbmm',
    }

    def _compute_tolerances(self, op, dtype):
        if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype in [torch.float32, torch.complex64]:
            return (1e-4, 3e-5)

        if op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16:
            return (1e-2, 1e-2)

        if op.name in ['nn.functional.conv_transpose1d',
                       'nn.functional.conv_transpose2d',
                       'nn.functional.conv_transpose3d',
                       '__rmatmul__', 'addbmm', 'addmv',
                       'baddbmm', 'cov', 'matmul', 'mv'] and dtype == torch.float16:
            return (5e-2, 5e-2)
        if op.name == "masked.mean":
            return (7e-4, 2e-3)
        if op.name == "native_layer_norm":
            return (1e-4, 1.3e-5)
        if op.name in ["pow", "__rpow__"] and product_version < 13.3:
            # The result of pow(9, 8) comes out as 43046716, whereas it should be 43046721;
            # fixed in macOS 13.3+
            return (1e-6, 2e-3 if dtype == torch.float16 else 4e-6)
        if op.name == "nn.functional.interpolate":
            return (1e-3, 1e-4)
        if op.name in ['fft.rfftn', 'fft.hfftn', 'fft.hfft2', 'fft.fft', 'fft.fftn', 'fft.rfft']:
            # TODO: Investigate why this is needed
            # See https://github.com/pytorch/pytorch/issues/120237
            return (3e-5, 3e-5)
        return (None, None)

    # Used for accept mode only
    NEW_ALLOW_LIST = defaultdict(list)
    NEW_ALLOW_LIST_GRAD = defaultdict(list)

    @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES + [torch.complex64])
    def test_output_match(self, device, dtype, op):
        self.assertEqual(device, "cpu")

        def get_samples():
            return op.sample_inputs(
                device,
                dtype,
                requires_grad=(dtype.is_floating_point or dtype.is_complex),
                # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
                set_seed=False,
            )
        cpu_samples = get_samples()

        for cpu_sample in cpu_samples:
            #
            # Forward check
            #
            mps_sample = cpu_sample.transform(
                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
            cpu_kwargs = cpu_sample.kwargs
            mps_args = [mps_sample.input] + list(mps_sample.args)
            mps_kwargs = mps_sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
                mps_args[1] = cpu_args[1]

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8:
                atol = 1.0
                rtol = 0.0

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

    @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
    def test_output_grad_match(self, device, dtype, op):
        self.assertEqual(device, "cpu")

        def get_samples():
            return op.sample_inputs(
                device,
                dtype,
                requires_grad=(dtype.is_floating_point or dtype.is_complex),
                # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
                set_seed=False,
            )
        cpu_samples = get_samples()

        for cpu_sample in cpu_samples:
            #
            # Forward check
            #
            forward_failed = False
            mps_sample = cpu_sample.transform(
                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
            cpu_kwargs = cpu_sample.kwargs
            mps_args = [mps_sample.input] + list(mps_sample.args)
            mps_kwargs = mps_sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
                mps_args[1] = cpu_args[1]

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

            if op.name == "unique" and cpu_kwargs["sorted"] is False:
                continue

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name in ["renorm", "norm", "linalg.norm"] and dtype == torch.float16:
                atol = 7e-4
                rtol = 1.5e-3

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

            #
            # Backward check
            #
            if forward_failed:
                # We would've failed immediately anyway, but this error is clearer
                # We error instead of continuing so that all_backward_pass would not be True
                raise RuntimeError("Forward pass already failed")

            cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out)
            mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out)

            def req_grad(t):
                return isinstance(t, torch.Tensor) and t.requires_grad

            diff_cpu_out = tuple(t for t in cpu_out if req_grad(t))
            diff_mps_out = tuple(t for t in mps_out if req_grad(t))
            diff_cpu_arg = tuple(t for t in pytree.tree_leaves((cpu_args, cpu_kwargs)) if req_grad(t))
            diff_mps_arg = tuple(t for t in pytree.tree_leaves((mps_args, mps_kwargs)) if req_grad(t))
            self.assertEqual(len(diff_cpu_out), len(diff_mps_out))
            self.assertEqual(len(diff_cpu_arg), len(diff_mps_arg))

            if len(diff_cpu_out) == 0:
                continue
            # rand_like does not work with certain dtypes, so cast to double and cast back
            cpu_grad_outputs = tuple(torch.rand_like(t, dtype=torch.double).to(dtype=t.dtype) for t in diff_cpu_out)
            mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs)

            # Compare computed gradients with cpu given random grad_output vector
            # Sometimes when the derivative is 0, we just don't bother creating the graph;
            # allow_unused is needed in those cases.
            cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
            mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)

            self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)


class TestErrorInputs(TestCase):
    _ignore_not_implemented_error = True

    @ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
    def test_error_inputs(self, device, op):
        self.assertEqual(device, "mps:0")

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        mps_samples = op.error_inputs(device, set_seed=False)

        for mps_sample in mps_samples:
            mps_sample_input = mps_sample.sample_input
            error_type = mps_sample.error_type
            error_regex = mps_sample.error_regex

            mps_args = [mps_sample_input.input] + list(mps_sample_input.args)
            mps_kwargs = mps_sample_input.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
                mps_args[1] = mps_args[1].cpu()

            with self.assertRaisesRegex(error_type, error_regex):
                op(*mps_args, **mps_kwargs)


class TestComplex(TestCase):
    def test_tensor_scalar_binops(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/119088
        def to_cpu(x):
            return x.cpu() if isinstance(x, torch.Tensor) else x

        # Allocate tensors on mps
        with torch.device("mps"):
            inputs = [torch.rand(2, dtype=dtype) for dtype in [torch.float, torch.half, torch.cfloat]]
        self.assertTrue(all(x.device.type == "mps" for x in inputs))
        # Add scalars
        inputs.extend([7, 3.14, 2 + 3j, torch.tensor(4 + 5j, dtype=torch.chalf)])



class TestErrorInputs(TestCase):
    _ignore_not_implemented_error = True

    @ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
    def test_error_inputs(self, device, op):
        self.assertEqual(device, "mps:0")

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        mps_samples = op.error_inputs(device, set_seed=False)

        for mps_sample in mps_samples:
            mps_sample_input = mps_sample.sample_input
            error_type = mps_sample.error_type
            error_regex = mps_sample.error_regex

            mps_args = [mps_sample_input.input] + list(mps_sample_input.args)
            mps_kwargs = mps_sample_input.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
                mps_args[1] = mps_args[1].cpu()

            with self.assertRaisesRegex(error_type, error_regex):
                op(*mps_args, **mps_kwargs)


class TestComplex(TestCase):
    def test_tensor_scalar_binops(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/119088
        def to_cpu(x):
            return x.cpu() if isinstance(x, torch.Tensor) else x

        # Allocate tensors on mps
        with torch.device("mps"):
            inputs = [torch.rand(2, dtype=dtype) for dtype in [torch.float, torch.half, torch.cfloat]]
        self.assertTrue(all(x.device.type == "mps" for x in inputs))

        # Add scalars
        inputs.extend([7, 3.14, 2 + 3j, torch.tensor(4 + 5j, dtype=torch.chalf)])

        # Iterate over all permutations of types (int, float, complex, half) and ops (excluding div)
        for x, y in itertools.product(inputs, inputs):
            for op_name in ["__add__", "__sub__", "__mul__"]:
                x_cpu, y_cpu = map(to_cpu, (x, y))
                res = getattr(x, op_name)(y)
                res_cpu = getattr(x_cpu, op_name)(y_cpu)
                self.assertEqual(to_cpu(res), res_cpu,
                                 f"{op_name}({x}, {y}) produces different results: {res} vs {res_cpu}")


# Copied from `TestCommon` in `test_ops.py`, just enough to duplicate `test_numpy_ref` for MPS
@skipIfSlowGradcheckEnv
class TestCommon(TestCase):
    exact_dtype = True

    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        if IS_CI:
            err_msg = (
                "The operator(s) below are using dynamic_dtypes in the OpInfo entries. "
                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
            )
            # Ensure no opinfo entry has dynamic_dtypes
            filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db))
            for op in filtered_ops:
                fmt_str = opinfo.utils.str_format_dynamic_dtype(op)
                err_msg += "\n" + fmt_str

            assert len(filtered_ops) == 0, err_msg

    # This is the MPS equivalent of `test_numpy_ref` from `test_ops.py`. It lives over here while
    # MPS still requires some fairly heavy special casing in the test framework.
    # When MPS becomes more consistent, this can probably be merged with that test using
    # `@dtypesIfMPS(torch.float32)`, but for now the assertions themselves need to be loosened.
    @suppress_warnings
    # MPS only supports float32
    @ops(_ref_test_ops, allowed_dtypes=(torch.float32,))
    def test_numpy_ref_mps(self, device, dtype, op):
        # Unlike `test_numpy_ref`, this test compares in `float32`, since at the time of this
        # test's creation MPS does not support float64 Tensors.
        # A few ops are currently broken on their reference inputs, but not their sample inputs.
        # These should get patched up and this workaround removed.
        broken_on_ref_inputs = op.name in ('where',)

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        inputs = (
            op.reference_inputs(device, dtype, set_seed=False) if not broken_on_ref_inputs
            else op.sample_inputs(device, dtype, set_seed=False)
        )
        for sample_input in inputs:
            self.compare_with_reference(op, op.ref, sample_input)

    @dtypes(*get_all_dtypes())
    def test_tensor_creation(self, device, dtype):
        def ones(device):
            return torch.ones((2, 2), dtype=dtype, device=device)
        if dtype not in MPS_DTYPES + ([torch.bfloat16, torch.complex64] if product_version > 14.0 else [torch.complex64]):
            with self.assertRaises(TypeError):
                ones(device)
        else:
            mps_tensor = ones(device)
            cpu_tensor = ones("cpu")
            self.assertEqual(mps_tensor.cpu(), cpu_tensor)
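

# A minimal, hedged sketch of the dtype gating exercised by `test_tensor_creation`
# above: dtypes the MPS backend cannot represent raise TypeError at construction
# time. Illustrative only and not collected as a test; float64 is one such dtype.
def _example_unsupported_dtype_sketch():
    try:
        torch.ones((2, 2), dtype=torch.float64, device="mps")
    except TypeError:
        # Expected: the MPS framework does not support float64 tensors
        pass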


# TODO: Actually instantiate these tests for the "mps" device to better reflect what they are doing.
# This requires mps to be properly registered in the device-generic test framework, which is not the
# case right now. We can probably use `allow_mps`, introduced in https://github.com/pytorch/pytorch/pull/87342,
# to achieve this.
instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
instantiate_parametrized_tests(TestMPS)

if __name__ == "__main__":
    run_tests()
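
# Usage note (illustrative; the exact generated names depend on the instantiations
# above): `instantiate_device_type_tests` expands each class per device and dtype,
# so a single generated case can be selected with unittest's -k filter, e.g.:
#
#   python test/test_mps.py -k test_output_match_abs_cpu_float32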