# Owner(s): ["module: mps"]

import io
import platform
import sys
import math
import random
import unittest
import warnings
import subprocess
import tempfile
import os
import copy
import gc
import threading
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
from collections import defaultdict
from torch import inf
from torch.nn import Buffer, Parameter
from torch.testing._internal import opinfo
from torch.testing._internal.common_utils import \
    (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, IS_CI,
     NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
import torch.backends.mps
from torch.distributions import Uniform, Exponential
from functools import partial

from torch.testing._internal.common_methods_invocations import (
    op_db,
    DecorateInfo,
    UnaryUfuncInfo,
    ReductionOpInfo,
    SpectralFuncInfo,
    BinaryUfuncInfo,
)
from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
import numpy as np
import torch.utils._pytree as pytree
from itertools import product
import operator

test_consistency_op_db = copy.deepcopy(op_db)
test_error_inputs_op_db = copy.deepcopy(op_db)
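# Deep copies: the mps_ops_*_modifier helpers below append decorators to these
# entries in place, so the shared `op_db` used by other suites stays untouched.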

# Copied from `test_ops.py` for the purposes of duplicating `test_numpy_ref`
_ref_test_ops = tuple(
    filter(
        lambda op: not isinstance(
            op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
        )
        and op.ref is not None,
        op_db,
    )
)

def xfailIf(condition):
    def wrapper(func):
        if condition:
            return unittest.expectedFailure(func)
        else:
            return func
    return wrapper
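# Example (illustrative): `@xfailIf(IS_CI)` marks a test as expected to fail
# only when the condition is true; otherwise the test runs normally.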

def xfailIfMacOS14_4Plus(func):
    return unittest.expectedFailure(func) if product_version > 14.3 else func  # noqa: F821
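# NB: `product_version` is assigned at module scope further down; these
# decorators are only applied after that assignment, hence the F821 suppression.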

def mps_ops_grad_modifier(ops):
    XFAILLIST_GRAD = {

        # precision issues
        'special.polygammaspecial_polygamma_n_0': [torch.float16],
        'polygammapolygamma_n_0': [torch.float16],
        'nn.functional.binary_cross_entropy': [torch.float16],

        # Unimplemented ops
        '__getitem__': [torch.float16],
        '_segment_reduce': [torch.float16, torch.float32],
        '_chunk_cat': [torch.float16, torch.float32],
        'unfold_copy': [torch.float16, torch.float32],  # unfold_backward is not implemented
        'unfold': [torch.float16, torch.float32],
        'sparse.mmreduce': [torch.float32],  # csr not supported
        'unique_consecutive': [torch.float16, torch.float32],
        'special_modified_bessel_i0': [torch.float16, torch.float32],
        'scalar_tensor': [torch.float16, torch.float32],
        'cdist': [torch.float32],
        'masked.scatter': [torch.float16, torch.float32],
        'index_fill': [torch.float16, torch.float32],  # missing `aten::_unique`.
        'linalg.lu_factor': [torch.float16, torch.float32],  # missing `aten::lu_unpack`.
        'aminmax': [torch.float32, torch.float16],

        # Correctness issues
        'atanh': [torch.float32],

        # Random output
        'exponential': [torch.float16, torch.float32],

        # CPU errors
        # derivative for aten::nextafter is not implemented on CPU
        'nextafter': None,
        # derivative for aten::floor_divide is not implemented on CPU
        'floor_divide': [torch.float16, torch.float32],
        # derivative for aten::narrow_copy is not implemented on CPU
        'narrow_copy': [torch.float16, torch.float32],
        # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU
        'histogramdd': [torch.float16, torch.float32],
        # derivative for aten::histogram is not implemented
        'histogram': [torch.float16, torch.float32],
        # 'bool' object is not iterable
        'allclose': [torch.float16, torch.float32],
        'equal': [torch.float16, torch.float32],
        # 'float' object is not iterable
        'item': [torch.float16, torch.float32],
        # "mse_backward_cpu_out" not implemented for 'Half'
        'nn.functional.mse_loss': [torch.float16],
        # "smooth_l1_backward_cpu_out" not implemented for 'Half'
        'nn.functional.smooth_l1_loss': [torch.float16],
        # cpu error: grad requires non-empty inputs
        'randn': [torch.float16, torch.float32],
        'signal.windows.bartlett': [torch.float32],
        'signal.windows.blackman': [torch.float32],
        'signal.windows.cosine': [torch.float32],
        'signal.windows.exponential': [torch.float32],
        'signal.windows.gaussian': [torch.float32],
        'signal.windows.general_cosine': [torch.float32],
        'signal.windows.general_hamming': [torch.float32],
        'signal.windows.hamming': [torch.float32],
        'signal.windows.hann': [torch.float32],
        'signal.windows.kaiser': [torch.float32],
        'signal.windows.nuttall': [torch.float32],
        'eye': [torch.float16, torch.float32],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],

        # atomic operation in backward pass
        '_unsafe_masked_index': [torch.float16],
        '_unsafe_masked_index_put_accumulate': [torch.float16],
    }

    MACOS_12_3_XFAILLIST_GRAD = {
        # Unsupported Border padding mode; the forward pass succeeds because it falls back to CPU
        'grid_sampler_2d': [torch.float32],
        # Unimplemented
        'logaddexp2': [torch.float32],
    }

    MACOS_BEFORE_13_3_XFAILLIST_GRAD = {
        # Failures due to precision issues (due to fast-math). These have been fixed in macOS 13.3+
        'masked.softmin': [torch.float32, torch.float16],
        'masked.softmax': [torch.float32, torch.float16],
        'masked.log_softmax': [torch.float32, torch.float16],

        # Unsupported Border padding mode; the forward pass succeeds because it falls back to CPU
        'grid_sampler_2d': [torch.float32],

        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # The forward pass passes since `msort` doesn't return the indices, just the values, which match the CPU.
        # The backward pass for `sort` uses both values and indices, resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    SKIPLIST_GRAD = {
        'nn.functional.pairwise_distance': [torch.float16],
        # failed assertion `destination datatype must be fp32'
        'nn.functional.conv1d': [torch.float16],
        'nn.functional.conv2d': [torch.float16],
        'nn.functional.conv3d': [torch.float16],
        'nn.functional.conv_transpose1d': [torch.float16],
        'nn.functional.conv_transpose2d': [torch.float16],
        'nn.functional.conv_transpose3d': [torch.float16],
    }

    MACOS_13_3_XFAILLIST_GRAD = {
        # Same issue as `argsort` and `sort` with duplicate elements (undefined behaviour).
        # The forward pass passes since `msort` doesn't return the indices, just the values, which match the CPU.
        # The backward pass for `sort` uses both values and indices, resulting in a mismatch between CPU and MPS.
        # Running `msort` with stable `sort` passes.
        'msort': [torch.float16],
    }

    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on the MPS backend
        # TODO: remove these once the downstream function 'aten::_linalg_svd.U' has been implemented
        'linalg.matrix_rank': None,

        # Exception: Caused by sample input at index 3 on MPS
        'nn.functional.conv3d': [torch.float32],
    }

    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

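    # Keys in the lists above combine `op.name` and `op.variant_test_name`,
    # e.g. 'divtrunc_rounding' is the 'div' op with the 'trunc_rounding' variant.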
    for op in ops:
        key = op.name + op.variant_test_name
        if key in XFAILLIST_GRAD:
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=XFAILLIST_GRAD[key]))

        if key in SKIPLIST_GRAD:
            addDecorator(op, DecorateInfo(
                         unittest.skip,
                         dtypes=SKIPLIST_GRAD[key]))

        if key in ON_MPS_XFAILLIST:
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=ON_MPS_XFAILLIST[key]))

        if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_12_3_XFAILLIST_GRAD[key]))

        if key in MACOS_BEFORE_13_3_XFAILLIST_GRAD and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_13_3_XFAILLIST_GRAD[key]))

        if key in MACOS_13_3_XFAILLIST_GRAD and (product_version >= 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_13_3_XFAILLIST_GRAD[key]))
        yield op

def mps_ops_modifier(ops):
    # Supported complex ops
    SUPPORTED_COMPLEX_OPS = {
        '__radd__',
        '__rmul__',
        '__getitem__',
        'abs',
        'add',
        'alias_copy',
        'argwhere',
        'atleast_1d',
        'atleast_2d',
        'atleast_3d',
        'as_strided',
        'as_strided_copy',
        'as_strided_scatter',
        'broadcast_tensors',
        'broadcast_to',
        'chalf',
        'cfloat',
        'chunk',
        'clone',
        'conj',
        'conj_physical',
        'contiguous',
        'diag',
        'diag_embed',
        'diagflat',
        'diagonal',
        'diagonal_copy',
        'diagonal_scatter',
        'dsplit',
        'empty',
        'empty_permuted',
        'empty_strided',
        'eye',
        'exp',
        'expand',
        'expand_as',
        'expand_copy',
        'flatten',
        'fill',
        'full',
        'H',
        'hsplit',
        'imag',
        'index_select',
        'isfinite',
        'isinf',
        'isreal',
        'item',
        'kron',
        'linalg.diagonal',
        'linalg.svd',
        'linspace',
        'logspace',
        'linspacetensor_overload',
        'logspacetensor_overload',
        'mH',
        'mT',
        'masked_scatter',
        'masked_select',
        'meshgridlist_of_tensors',
        'meshgridvariadic_tensors',
        'movedim',
        'mul',
        'narrow',
        'narrow_copy',
        'nn.functional.conv1d',
        'nn.functional.conv2d',
        'nn.functional.conv_transpose1d',
        'nn.functional.conv_transpose2d',
        'nn.functional.feature_alpha_dropoutwithout_train',
        'nn.functional.padcircular',
        'nn.functional.tanhshrink',
        'nn.functional.unfold',
        'nonzero',
        'ones',
        'outer',
        'permute',
        'positive',
        'randn',
        'ravel',
        'real',
        'repeat_interleave',
        'reshape_as',
        'reshape',
        'resolve_conj',
        'resolve_neg',
        'scalar_tensor',
        'select',
        'sgn',
        'slice',
        'split',
        'split_with_sizes',
        'split_with_sizes_copy',
        'splitlist_args',
        'squeeze',
        'squeezemultiple',
        'sub',
        'svd',
        't',
        't_copy',
        'tanh',
        'tensor_split',
        'transpose',
        'T',
        'unbind',
        'unflatten',
        'unfold',
        'unfold_copy',
        'unsafe_chunk',
        'unsafe_split',
        'unsqueeze',
        'unsqueeze_copy',
        'view_as',
        'view_as_real',
        'view',
        'view_copy',
        'vsplit',
        'zero_',
        'zeros',
    }

    AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = {
        '__rdiv__',
        '__rmatmul__',
        '_chunk_cat',
        '_unsafe_masked_index',
        'acos',
        'acosh',
        'all',
        'allclose',
        'any',
        'addcdiv',
        'addcmul',
        'addmmdecomposed',
        'addmv',
        'asin',
        'atan',
        'atanh',
        'bfloat16',
        'bmm',
        'bool',
        'cartesian_prod',
        'cat',
        'char',
        'column_stack',
        'combinations',
        'corrcoef',
        'constant_pad_nd',
        'cos',
        'cosh',
        'count_nonzero',
        'diff',
        'div',
        'divno_rounding_mode',
        'dot',
        'dstack',
        'einsum',
        'eq',
        'equal',
        'exp2',
        'expm1',
        'fft.fft',
        'fft.fft2',
        'fft.fftn',
        'fft.fftshift',
        'fft.ifft',
        'fft.ifft2',
        'fft.ifftn',
        'fft.ifftshift',
        'fft.irfftn',
        'fft.irfft2',
        'fft.irfft',
        'fft.hfftn',
        'fft.hfft2',
        'fft.hfft',
        'flip',
        'fliplr',
        'flipud',
        'float',
        'gradient',
        'half',
        'hstack',
        'inner',
        'int',
        'isclose',
        'isnan',
        'ldexp',
        'linalg.multi_dot',
        'linalg.pinv',
        'log10',
        'log1p',
        'log2',
        'log',
        'logical_and',
        'logical_not',
        'logical_or',
        'logical_xor',
        'logsumexp',
        'long',
        'masked_fill',
        'masked.mean',
        'masked.prod',
        'masked.std',
        'masked.sum',
        'masked.var',
        'masked.logsumexp',
        'matmul',
        'mean',
        'mm',
        'mv',
        'ne',
        'neg',
        'nn.functional.padconstant',
        'nn.functional.padreflect',
        'nn.functional.padreplicate',
        'nn.functional.pixel_shuffle',
        'nn.functional.pixel_unshuffle',
        'nn.functional.rms_norm',
        'nn.functional.softsign',
        'pinverse',
        'prod',
        'reciprocal',
        'roll',
        'rot90',
        'rsqrt',
        'short',
        'sigmoid',
        'sin',
        'sinh',
        'sqrt',
        'square',
        'stack',
        'stft',
        'sum',
        'sum_to_size',
        'tan',
        'tensordot',
        'trace',
        'trapz',
        'trapezoid',
        'tril',
        'triu',
        'true_divide',
        'vstack',
        'where',
        'byte',
    }
    # These ops worked on macOS 12 but are broken on macOS 13, see https://github.com/pytorch/pytorch/issues/85758
    MACOS_12_3_XFAILLIST = {
        # Top 60
        # expected failures
        # pow(9, 8) returns 43046716, whereas it should be 43046721.
        # Fixed in macOS 13.3. Currently no error is raised.
        'pow': [torch.int16, torch.int64, torch.uint8, torch.int8],
        # expected failures
        '__rpow__': [torch.uint8, torch.int8],

        # Failures due to precision issues (due to fast-math). These have been fixed in macOS 13.3+
        'cdist': [torch.float32],
        'tan': [torch.uint8, torch.float32],

        # Data type support starts from macOS 13
        'nn.functional.avg_pool1d': [torch.int64],
        'nn.functional.avg_pool2d': [torch.int64],
        'nn.functional.local_response_norm': [torch.int64],
        '__radd__': [torch.uint8],
        '__rdiv__': [torch.uint8],
        '__rmul__': [torch.uint8],
        'abs': [torch.uint8],
        'acos': [torch.uint8],
        'acosh': [torch.uint8],
        'add': [torch.uint8],
        'asin': [torch.uint8],
        'asinh': [torch.uint8],
        'atan': [torch.uint8],
        'atanh': [torch.uint8],
        'ceil': [torch.uint8],
        'corrcoef': [torch.uint8],
        'cos': [torch.uint8],
        'cosh': [torch.uint8],
        'cov': [torch.uint8],
        'cumulative_trapezoid': [torch.uint8],
        'deg2rad': [torch.uint8],
        'diff': [torch.uint8],
        'eq': [torch.uint8],
        'equal': [torch.uint8],
        'erf': [torch.uint8],
        'exp2': [torch.uint8],
        'exp': [torch.uint8],
        'expm1': [torch.uint8],
        'floor': [torch.uint8],
        'fmax': [torch.uint8],
        'fmin': [torch.uint8],
        'fmod': [torch.uint8],
        'ge': [torch.uint8],
        'gt': [torch.uint8],
        'isclose': [torch.uint8],
        'isnan': [torch.uint8],
        'kron': [torch.uint8],
        'le': [torch.uint8],
        'log10': [torch.uint8],
        'log1p': [torch.uint8],
        'log2': [torch.uint8],
        'log': [torch.uint8],
        'logical_and': [torch.uint8],
        'logical_or': [torch.uint8],
        'logical_xor': [torch.uint8],
        'logit': [torch.uint8],
        'lt': [torch.uint8],
        'masked.mean': [torch.uint8],
        'masked.std': [torch.uint8],
        'masked.var': [torch.uint8],
        'maximum': [torch.uint8],
        'minimum': [torch.uint8],
        'mul': [torch.uint8],
        'ne': [torch.uint8],
        'neg': [torch.uint8],
        'nn.functional.cosine_embedding_loss': [torch.uint8],
        'nn.functional.margin_ranking_loss': [torch.uint8],
        'nn.functional.poisson_nll_loss': [torch.uint8],
        'nn.functional.softsign': [torch.uint8],
        'nn.functional.tanhshrink': [torch.uint8],
        'nn.functional.triplet_margin_loss': [torch.uint8],
        'nn.functional.triplet_margin_with_distance_loss': [torch.uint8],
        'nn.functional.pairwise_distance': [torch.uint8],
        'outer': [torch.uint8],
        'rad2deg': [torch.uint8],
        'reciprocal': [torch.uint8],
        'remainder': [torch.uint8],
        'round': [torch.uint8],
        'rsqrt': [torch.uint8],
        'sigmoid': [torch.uint8],
        'sign': [torch.uint8],
        'signbit': [torch.uint8],
        'sin': [torch.uint8],
        'sinh': [torch.uint8],
        'special.ndtr': [torch.uint8],
        'sqrt': [torch.uint8],
        'sub': [torch.uint8],
        'trapezoid': [torch.uint8],
        'trapz': [torch.uint8],
        'true_divide': [torch.uint8],
        'trunc': [torch.uint8],
        'xlogy': [torch.uint8],
        'minbinary': [torch.uint8],
        'maxbinary': [torch.uint8],
        'divtrunc_rounding': [torch.uint8],
        'divfloor_rounding': [torch.uint8],
        'divno_rounding_mode': [torch.uint8],
        'floor_divide': [torch.uint8],
        'ldexp': [torch.uint8],
        # square internally calls into power, and will type cast to int64, which is supported starting with macOS 13
        'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # inconsistency errors between cpu and mps, max seen atol is 2
        'nn.functional.interpolatebilinear': [torch.uint8],
    }

    MACOS_BEFORE_13_3_XFAILLIST = {
        # Failures due to precision issues (due to fast-math). These have been fixed in macOS 13.3+
        'tan': [torch.float32],
        'cdist': [torch.float32],

        # CPU Error: cpu not giving nan for x/0.0
        'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # The tests below pass on macOS 12 as they fall back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], device='cpu')
        #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
        # The elements at indices 30 and 5133 are equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but for the returned indices the behaviour is undefined.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
        # Unsupported dtypes
        'cumsum': [torch.int64],
        'cumprod': [torch.int64],
        'cumulative_trapezoid': [torch.int64],
        'masked.cumsum': [torch.int64],
        'masked.cumprod': [torch.int64],
        'linalg.vander': [torch.int64],
    }

    MACOS_AFTER_13_1_XFAILLIST = {
        # before macOS 13.2 it falls back to cpu and passes the forward pass
        'grid_sampler_2d': [torch.float32],  # Unsupported Border padding mode
        # inconsistency errors between cpu and mps, max seen atol is 2
        'nn.functional.interpolatebilinear': [torch.uint8],
    }

    MACOS_13_3_XFAILLIST = {
        # Failure due to precision issue for fp16
        # on both cpu and mps there are test cases that might produce inf result
        # 'nn.functional.pairwise_distance': [torch.float16],

        # The tests below pass on macOS 12 as they fall back to cpu
        # Argsort case using duplicate indices (undefined behaviour):
        #  - CPU output: tensor([2546, 6917, 3181,  ..., 7128, 5133,   30], device='cpu')
        #  - MPS output: tensor([2546, 6917, 3181,  ..., 7128,   30, 5133], device='mps:0')
        # The elements at indices 30 and 5133 are equal.
        # Since CPU is not using argsort with stable=True, these cases result in undefined behaviour.
        'argsort': [torch.float16, torch.int8, torch.uint8, torch.bool],
        # Same issue as `argsort` with duplicate indices. This test checks both the sorted values and the indices.
        # The values of the sorted tensor match the CPU, but for the returned indices the behaviour is undefined.
        'sort': [torch.int8, torch.uint8, torch.bool, torch.float16],
    }

    MACOS_BEFORE_14_4_XFAILLIST = {
        # These ops work fine in 14.4 but fail in 14.2 or 13.x
        'fft.hfft2': [torch.complex64],
    }

    # These ops are not expected to work
    UNIMPLEMENTED_XFAILLIST = {
        # Failures due to lack of op implementation on MPS backend
        'login': None,
        'linalg.eig': None,
        'linalg.eigvals': None,
        'put': None,
        'nn.functional.conv_transpose3d': None,
        'rounddecimals_neg_3': None,
        'rounddecimals_3': None,
        'rounddecimals_0': None,
        '__rsub__': None,
        'angle': None,
        'cauchy_': None,
        'cauchy': None,
        'cholesky': None,
        'cholesky_inverse': None,
        'cholesky_solve': None,
        'cummax': None,
        'cummin': None,
        'erfc': None,
        'frexp': None,
        'gcd': None,
        'geqrf': None,
        'nn.functional.grid_sample': None,  # Unsupported Border padding mode
        'heaviside': None,
        'i0': None,
        'igamma': None,
        'igammac': None,
        'index_copy': None,
        'index_reduceprod': None,
        'index_reducemean': None,
        'index_reduceamax': None,
        'index_reduceamin': None,
        'isneginf': None,
        'isposinf': None,
        'kthvalue': None,
        'lcm': None,
        'linalg.cholesky': None,
        'linalg.cholesky_ex': None,
        'linalg.cond': None,
        'linalg.detsingular': None,
        'linalg.det': None,
        'linalg.eigh': None,
        'linalg.eigvalsh': None,
        'linalg.householder_product': None,
        'linalg.ldl_factor': None,
        'linalg.ldl_factor_ex': None,
        'linalg.ldl_solve': None,
        'linalg.lstsq': None,
        'linalg.lstsqgrad_oriented': None,
        'linalg.lu': None,
        'linalg.lu_factor_ex': None,
        'linalg.lu_solve': None,
        'linalg.matrix_norm': [torch.float32],
        'linalg.norm': [torch.float32],
        'linalg.normsubgradients_at_zero': [torch.float32],
        'linalg.qr': None,
        'linalg.slogdet': None,
        'linalg.solve': None,
        'linalg.solve_ex': None,
        'linalg.svdvals': None,
        'linalg.tensorsolve': None,
        'linalg.vecdot': None,
        'logcumsumexp': None,
        'logdet': None,
        'lu': None,
        'lu_solve': None,
        'lu_unpack': None,
        'masked.median': None,
        'matrix_exp': None,
        'mode': None,
        'nanmedian': None,
        'native_dropout_backward': None,
        'normnuc': None,
        'nn.functional.fractional_max_pool2d': None,
        'nn.functional.fractional_max_pool3d': None,
        'nn.functional.adaptive_avg_pool3d': None,
        'nn.functional.adaptive_max_pool3d': None,
        'nn.functional.interpolatearea': None,
        'nn.functional.interpolatebicubic': None,
        'nn.functional.interpolatetrilinear': None,
        'nn.functional.max_unpool1dgrad': None,
        'nn.functional.max_unpool2dgrad': None,
        'nn.functional.max_unpool3dgrad': None,
        'nn.functional.avg_pool3d': None,
        'nn.functional.ctc_loss': None,
        'nn.functional.embedding_bag': None,
        'nn.functional.hardshrink': None,
        'nn.functional.max_pool3d': None,
        'nn.functional.max_unpool1d': None,
        'nn.functional.max_unpool2d': None,
        'nn.functional.max_unpool3d': None,
        'nn.functional.multi_margin_loss': None,
        'nn.functional.multilabel_margin_loss': None,
        'nn.functional.pdist': None,
        'nn.functional.rrelu': None,
        'nn.functional.norm': None,
        'ormqr': None,
        'pca_lowrank': None,
        'qr': None,
        'rsub': None,
        'scatter_reduceamax': None,
        'scatter_reduceamin': None,
        'scatter_reducemin': None,
        'scatter_reducemean': None,
        'scatter_reduceprod': None,
        'scatter_reducesum': None,
        'segment_reduce': None,
        '_segment.reduce': None,
        'segment.reduce': None,
        'segment_reduce_offsets': None,
        '_segment_reduce_offsets': None,
        '_segment_reduce_lengths': None,
        '_segment_reducelengths': None,
        '_segment_reduceoffsets': None,
        'sinc': None,
        'sparse.mm': None,
        'sparse.mmreduce': None,
        'special.airy_ai': None,
        'special.bessel_j0': None,
        'special.bessel_j1': None,
        'special.bessel_y0': None,
        'special.bessel_y1': None,
        'special.chebyshev_polynomial_t': None,
        'special.chebyshev_polynomial_u': None,
        'special.entr': None,
        'special.erfcx': None,
        'special.hermite_polynomial_h': None,
        'special.hermite_polynomial_he': None,
        'special.i0e': None,
        'special.i1': None,
        'special.i1e': None,
        'special.laguerre_polynomial_l': None,
        'special.log_ndtr': None,
        'special.modified_bessel_i0': None,
        'special.modified_bessel_i1': None,
        'special.modified_bessel_k0': None,
        'special.modified_bessel_k1': None,
        'special.ndtri': None,
        'special.scaled_modified_bessel_k0': None,
        'special.scaled_modified_bessel_k1': None,
        'special.spherical_bessel_j0': None,
        'special.xlog1py': None,
        'special.zeta': None,
        'svd_lowrank': None,
        'symeig': None,
        'take': None,
        'to': None,
        'to_sparse': None,
        'unique': None,
        'vdot': None,
        'segment_reduce_': None,
        '_upsample_bilinear2d_aa': None,
        'geometric': None,
        'geometric_': None,
        'log_normal_': None,
        'log_normal': None,
        'cdouble': None,
        'double': None,
        'nn.functional.softminwith_dtype': None,
        'log_softmaxwith_dtype': None,
        'softmaxwith_dtype': None,
        'float_power': None,
        'full_like': None,
        'linalg.matrix_rankhermitian': None,
        'linalg.pinvhermitian': None,
        'nonzero_static': None,

        # MPS: input sizes must be divisible by output sizes
        'nn.functional.adaptive_avg_pool1d': None,
        'nn.functional.adaptive_avg_pool2d': None,

        # Unsupported dtypes
        # bmm is not supported for integral types
        'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'ones_like': None,
        'zeros_like': None,

        # Convolution for integral types is not supported on MPS
        'nn.functional.conv1d': [torch.int64],
        'nn.functional.conv2d': [torch.int64],
        'nn.functional.conv3d': [torch.int64],
        'nn.functional.conv_transpose1d': [torch.int64],
        'nn.functional.conv_transpose2d': [torch.int64],

        # Unsupported dtypes
        'dot': [torch.int64],
        'histc': [torch.float16],
        'index_add': [torch.int64],
        'log1p': [torch.int64],
        'sigmoid': [torch.int64],
        'atan2': [torch.int64],

        # GEMM on MPS is not supported for integral types
        'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        '__rmatmul__': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'addmv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'baddbmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'bmm': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'einsum': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'inner': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'linalg.multi_dot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'matmul': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mat': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'mv': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'tensordot': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'unravel_index': [torch.int32, torch.int64],

        # new_zeros/new_ones: Cannot convert a MPS Tensor to float64 dtype as
        # the MPS framework doesn't support float64
        'new_zeros': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_ones': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # returned output on CPU is float64
        'bincount': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # trunc_tensor not working properly for float16
        'divtrunc_rounding': [torch.float16],
        'fmod': [torch.float16],

        # round not working properly for float16
        'round': [torch.float16],

        # atomic operations not supported
        '_unsafe_masked_index_put_accumulate': [torch.bool, torch.int8, torch.uint8, torch.float16, torch.int16, torch.int64],
    }

    if product_version < 14.0:
        # FFT and BFloat16 support was added in macOS 14
        UNIMPLEMENTED_XFAILLIST.update({
            'bfloat16': None,
            'fft.fft': None,
            'fft.fft2': None,
            'fft.fftn': None,
            'fft.hfft': None,
            'fft.hfft2': None,
            'fft.hfftn': None,
            'fft.ifft': None,
            'fft.ifft2': None,
            'fft.ifftn': None,
            'fft.ihfft': None,
            'fft.ihfft2': None,
            'fft.ihfftn': None,
            'fft.irfft': None,
            'fft.irfft2': None,
            'fft.irfftn': None,
            'fft.rfft': None,
            'fft.rfft2': None,
            'fft.rfftn': None,
            'stft': None,
            # TestConsistencyCPU.test_output_match_isin_cpu fails for integer dtypes;
            # not reproducible on later OS versions. An assert was added to the op when used on < 14.0.
            'isin': [torch.int64, torch.int32, torch.int16, torch.uint8, torch.int8],
            'nn.functional.max_pool2d': [torch.uint8],
        })

    if product_version < 15.0:
        UNIMPLEMENTED_XFAILLIST.update({
            'quantile': None,
            'nanquantile': None,
        })

    UNDEFINED_XFAILLIST = {
        # Top 60 operators
        # topk fails with duplicate indices
        'topk': [torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],

        # Failures due to the random output these ops generate using the
        # Philox engine, causing a mismatch with CPU results
        'multinomial': [torch.float16, torch.float32],  # random results
        'uniform': [torch.float16, torch.float32],
        'rand_like': [torch.float16, torch.float32],
        'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'randn_like': [torch.float16, torch.float32],
        'bernoulli': [torch.float16, torch.float32],
        'exponential': [torch.float16, torch.float32],
        'nn.functional.feature_alpha_dropoutwith_train': [torch.float16, torch.float32],
        'normal': [torch.float16, torch.float32],
        'normalin_place': [torch.float16, torch.float32],
        'normalnumber_mean': [torch.float16, torch.float32],
        'nn.functional.alpha_dropout': [torch.float16, torch.float32],
        'nn.functional.dropout': [torch.float16, torch.float32],
        'nn.functional.dropout2d': [torch.float16, torch.float32],
        'nn.functional.dropout3d': [torch.float16, torch.float32],
        # See https://github.com/pytorch/pytorch/issues/111479
        'nn.functional.multi_head_attention_forward': [torch.float32, torch.float16],

        # duplicate indices are used in the testcase - undefined behaviour
        'index_put': None,
        # zero to negative integer powers are undefined
        '__rpow__': [torch.int8, torch.int16, torch.int32, torch.int64],
        'resize_': [torch.float16, torch.float32],
        'resize_as_': [torch.float16, torch.float32],

        # CPU Errors:
        'addr': [torch.bool, torch.int16, torch.int32,
                 torch.int64, torch.uint8, torch.int8],  # "addmv_impl_cpu" not implemented for 'Half'
        'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
                                    torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values
        'as_strided_partial_views': [torch.bool, torch.float16, torch.float32, torch.int16,
                                     torch.int32, torch.int64, torch.uint8, torch.int8],  # cpu result off, showing random values

        # random results
        # mps vs cpu:
        # Mismatched elements: 40 / 96 (41.7%)
        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        # cuda(2.0.0.dev20230301+cu117) vs cpu:
        # Mismatched elements: 56 / 96 (58.3%)
        # Greatest absolute difference: 17.892311096191406 at index (1, 0, 2) (up to 1e-05 allowed)
        # Greatest relative difference: inf at index (1, 0, 0) (up to 1.3e-06 allowed)
        'nn.functional.scaled_dot_product_attention': [torch.float32, torch.float16],

        # float output for float16 input on MPS
        'logit': [torch.float16],
    }

    ON_MPS_XFAILLIST = {
        # Failures due to lack of implementation of downstream functions on the MPS backend
        # TODO: remove these once the downstream function 'aten::_linalg_svd.U' has been implemented
        'linalg.matrix_rank': None,
    }

    EMPTY_OPS_SKIPLIST = {
        # These ops fill tensors with uninitialized data, causing a mismatch with CPU.
        # They occasionally match, thus skipping them.
        # See https://github.com/pytorch/pytorch/issues/100175
        'new_empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16,
                              torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        # CPU: empty returns all 0's here, mismatching the MPS allocation (macOS 13); per
        # https://pytorch.org/docs/2.0/generated/torch.empty.html the contents are uninitialized.
        'empty': [torch.bool, torch.float16, torch.float32, torch.int16,
                  torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
        'empty_permuted': [torch.bool, torch.float16, torch.float32, torch.int16,
                           torch.int32, torch.int64, torch.uint8, torch.int8],
    }

    SKIPLIST = {
        # Unsupported
        # input types 'tensor<1x3x9x9xf16>' and 'tensor<1xf32>' are not broadcast compatible
        'nn.functional.avg_pool2d': [torch.float16],

        # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16
        'nn.functional.conv3d': None,
    }

    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        key = op.name + op.variant_test_name
        if key in EMPTY_OPS_SKIPLIST:
            addDecorator(op, DecorateInfo(
                         unittest.skip("Skipping empty ops."),
                         dtypes=EMPTY_OPS_SKIPLIST[key]))
        if key in SKIPLIST:
            addDecorator(op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]))
        for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST, ON_MPS_XFAILLIST]:
            if key in xfaillist:
                addDecorator(op, DecorateInfo(
                             unittest.expectedFailure,
                             dtypes=xfaillist[key]))

        if key in MACOS_BEFORE_14_4_XFAILLIST and (product_version < 14.4):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_14_4_XFAILLIST[key]))

        if key in MACOS_BEFORE_13_3_XFAILLIST and (torch.backends.mps.is_macos13_or_newer() and product_version < 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_BEFORE_13_3_XFAILLIST[key]))

        if key in MACOS_AFTER_13_1_XFAILLIST and torch.backends.mps.is_macos13_or_newer(2):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_AFTER_13_1_XFAILLIST[key]))

        if key in MACOS_13_3_XFAILLIST and (product_version >= 13.3):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_13_3_XFAILLIST[key]))

        if key in MACOS_12_3_XFAILLIST and (not torch.backends.mps.is_macos13_or_newer()):
            addDecorator(op, DecorateInfo(
                         unittest.expectedFailure,
                         dtypes=MACOS_12_3_XFAILLIST[key]))

        # If the op is not supported for complex types, expect it to fail
        if key not in SUPPORTED_COMPLEX_OPS and (key not in AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS or product_version < 14.0):
            addDecorator(op, DecorateInfo(unittest.expectedFailure, dtypes=[torch.complex32, torch.complex64]))

        yield op

def mps_ops_error_inputs_modifier(ops):
    # Error input samples do not take a dtype argument.
    XFAILLIST = {
        # Exceptions are not raised
        '__rmod__',
        '__rsub__',
        '__rpow__',
        'bernoulli',
        'clamp_max',
        'clamp_min',
        'masked_scatter',

        # unsupported float64 dtype
        'cat',
        'complex',
        'multinomial',
        'nn.functional.conv1d',
        'nn.functional.conv2d',
        'nn.functional.conv3d',
        'gather',
        'scatter',
        'scatter_add',

        # unsupported complex dtypes
        'masked_fill',

        # MPS does not support tensor dimensions > 16
        'amax',
        'amin',
        'aminmax',

        # memory overlapping checks
        'index_select',

        # unimplemented
        'logcumsumexp',
    }

    def addDecorator(op, d) -> None:
        op.decorators = list(op.decorators) if op.decorators is not None else []
        op.decorators.append(d)

    for op in ops:
        if op.error_inputs_func is None:
            continue
        key = op.name + op.variant_test_name
        if key in XFAILLIST:
            addDecorator(op, DecorateInfo(unittest.expectedFailure))
        yield op

# Same logic as test_cuda.py
if not torch.backends.mps.is_available():
    print('MPS not available, skipping tests', file=sys.stderr)
    TestCase = NoTest  # noqa: F811
    NNTestCase = NoTest  # noqa: F811

product_version = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)
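# NB: a mac_ver() string like '14.4.1' parses to 14.4 above; on non-macOS
# platforms mac_ver() returns an empty version string, so this falls back to -1.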
total_memory = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"]))
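# Total physical memory in bytes, as reported by `sysctl -n hw.memsize`.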

# Determine whether to enable MPS memory leak check (uses same code as CUDA).
TEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'

def skipMPSMemoryLeakCheckIf(condition):
    def dec(fn):
        if getattr(fn, '_do_mps_memory_leak_check', True):
            fn._do_mps_memory_leak_check = not condition
        return fn
    return dec
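# Example (illustrative): decorating a test with `@skipMPSMemoryLeakCheckIf(True)`
# opts it out of the leak check, e.g. when it intentionally caches MPS memory.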

class MpsMemoryLeakCheck:
    def __init__(self, testcase, name=None):
        self.name = testcase.id() if name is None else name
        self.testcase = testcase

    def __enter__(self):
        # Performs a gc if required (required if any memory is held)
        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
        if caching_allocator_mem_allocated > 0:
            gc.collect()
            torch.mps.empty_cache()

        # Acquires caching allocator and driver statistics before the test is run
        self.caching_allocator_before = torch.mps.current_allocated_memory()
        self.driver_before = torch.mps.driver_allocated_memory()

    def __exit__(self, exec_type, exec_value, traceback):
        # Don't check for leaks if an exception was thrown
        if exec_type is not None:
            return
        # Compares caching allocator before/after statistics
        # An increase in allocated memory is a discrepancy indicating a possible memory leak
        discrepancy_detected = False
        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
        if caching_allocator_mem_allocated > self.caching_allocator_before:
            discrepancy_detected = True

        # Short-circuits if no discrepancy detected
        if not discrepancy_detected:
            return
        # Validates the discrepancy persists after garbage collection and
        # is confirmed by the driver API
        gc.collect()
        torch.mps.empty_cache()

        discrepancy_detected = True
        # Query memory multiple times to ensure the leak was not transient
        for n in range(3):
            caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
            driver_mem_allocated = torch.mps.driver_allocated_memory()

            caching_allocator_discrepancy = False
            driver_discrepancy = False

            if caching_allocator_mem_allocated > self.caching_allocator_before:
                caching_allocator_discrepancy = True

            if driver_mem_allocated > self.driver_before:
                driver_discrepancy = True

            if not (caching_allocator_discrepancy or driver_discrepancy):
                # Leak was false positive, exit loop
                discrepancy_detected = False
                break

        if caching_allocator_discrepancy and not driver_discrepancy:
            # Just raises a warning if the leak is not validated by the driver API
            msg = ("MPS caching allocator reports a memory leak not "
                   f"verified by the driver API in {self.name}! "
                   f"Caching allocator allocated memory was {self.caching_allocator_before} "
                   f"and is now reported as {caching_allocator_mem_allocated}. "
                   f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")
            warnings.warn(msg)
        elif caching_allocator_discrepancy and driver_discrepancy:
            # A caching allocator discrepancy validated by the driver API is a failure
            msg = (f"MPS driver API confirmed a leak in {self.name}! "
                   f"Caching allocator allocated memory was {self.caching_allocator_before} "
                   f"and is now reported as {caching_allocator_mem_allocated}. "
                   f"MPS driver allocated memory was {self.driver_before} and is now {driver_mem_allocated}.")

            raise RuntimeError(msg)

class TestAutocastMPS(TestCase):

    def test_matmul_autocast(self):
        autocast_tensor_A = torch.rand((8, 8), device="mps")
        autocast_tensor_B = torch.rand((8, 8), device="mps")
        tensor_A = autocast_tensor_A.clone().detach()
        tensor_B = autocast_tensor_B.clone().detach()
        autocast_output_tensor = torch.empty(8, 8)
        output_tensor = autocast_output_tensor.clone().detach()

        with torch.autocast(device_type="mps"):
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_tensor_B)
            autocast_output_tensor = torch.mm(autocast_tensor_A, autocast_output_tensor)

        output_tensor = torch.mm(tensor_A, tensor_B)
        output_tensor = torch.mm(tensor_A, output_tensor)

        self.assertEqual(autocast_output_tensor.dtype, torch.float16, "Autocast output tensor was not expected type float16")
        self.assertEqual(autocast_output_tensor,
                         output_tensor.to(torch.float16),
                         f"Autocast & non-autocast tensors did not match, \
                         got:\n{autocast_output_tensor} \n{output_tensor.to(torch.float16)}")

# Extend the TestCase class with memory leak detection on the MPS device
class TestCaseMPS(TestCase):
    _do_mps_memory_leak_check = True

    def __init__(self, method_name='runTest'):
        super().__init__(method_name)
        test_method = getattr(self, method_name, None)
        if test_method is not None:
            # Wraps the tested method if we should do MPS memory check.
            if TEST_MPS_MEM_LEAK_CHECK:
                if self._do_mps_memory_leak_check:
                    self.wrap_with_mps_policy(method_name, self.assertLeaksNoMpsTensors)

    def assertLeaksNoMpsTensors(self, name=None):
        name = self.id() if name is None else name
        return MpsMemoryLeakCheck(self, name)

    def wrap_with_mps_policy(self, method_name, policy):
        test_method = getattr(self, method_name)
        setattr(self, method_name, super().wrap_method_with_policy(test_method, policy))

    # checks for leaks even if TEST_MPS_MEM_LEAK_CHECK is 0
    def wrap_with_mps_memory_check(self, method):
        return super().wrap_method_with_policy(method, self.assertLeaksNoMpsTensors)

class TestMemoryLeak(TestCaseMPS):
    def test_mps_memory_leak_detection(self):
        l = []

        @self.wrap_with_mps_memory_check
        def no_leak():
            pass

        # Trigger an intentional memory leak
        @self.wrap_with_mps_memory_check
        def leak_gpu0():
            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
            l.append(torch.randn(1024 * 1024 * 8, device=torch.device("mps")))

        no_leak()

        # Check that a RuntimeError for the memory leak is raised, which
        # confirms that memory leak detection works.
1272        with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
1273            leak_gpu0()
1274
1275    def test_copy_cast_no_leak(self):
1276
1277        def step(x):
1278            x = x.to(device='cpu', dtype=torch.float32)
1279            x = x.to(device='mps', dtype=torch.float16)
1280
1281        a = torch.randn(128, 128, device='mps', dtype=torch.float16)
1282        # Warm up / prebuild MPS shaders (otherwise check fails on 13.2)
1283        step(a)
1284        torch.mps.empty_cache()
1285        driver_before = torch.mps.driver_allocated_memory()
1286        step(a)
1287        torch.mps.empty_cache()
1288        driver_after = torch.mps.driver_allocated_memory()
1289        self.assertEqual(driver_before, driver_after, f"Detected {driver_after-driver_before} bytes leak of GPU memory")
1290
1291
1292class TestPixelShuffle(TestCaseMPS):
1293    def test_pixel_shuffle_unshuffle(self):
1294        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
1295                                                 upscale_factor=None, is_contiguous=True):
1296
1297            def generate_input():
1298                # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
1299                channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
1300                height = random.randint(5, 10)
1301                width = random.randint(5, 10)
1302
1303                if num_input_dims == 1:
1304                    input = torch.rand(channels, requires_grad=True, device='mps')
1305                    assert is_contiguous
1306                elif num_input_dims == 2:
1307                    input = torch.rand(width, height, requires_grad=True, device='mps').T
1308                    if is_contiguous:
1309                        input = input.contiguous()
1310                else:
1311                    batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
1312                    input = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
1313                    input = input.transpose(-1, -2)
1314                    if is_contiguous:
1315                        input = input.contiguous()
1316
1317                if not is_contiguous and len(input.reshape(-1)) > 0:
1318                    assert not input.is_contiguous()
1319
1320                input = input.detach().clone()
1321                input.requires_grad = True
1322                return input
1323
1324            # Function to imperatively ensure pixels are shuffled to the correct locations.
1325            # Used to validate the batch operations in pixel_shuffle.
1326            def _verify_pixel_shuffle(input, output, upscale_factor):
1327                for c in range(output.size(-3)):
1328                    for h in range(output.size(-2)):
1329                        for w in range(output.size(-1)):
1330                            height_idx = h // upscale_factor
1331                            weight_idx = w // upscale_factor
1332                            channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
1333                                          (c * upscale_factor ** 2)
1334                            self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])
1335
1336            upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
1337            input = generate_input()
1338
1339            ps = nn.PixelShuffle(upscale_factor)
1340            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
1341
1342            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
1343                output = ps(input)
1344                _verify_pixel_shuffle(input, output, upscale_factor)
1345                output.backward(output.data)
1346                self.assertEqual(input.data, input.grad.data)
1347
1348                # Ensure unshuffle properly inverts shuffle.
1349                unshuffle_output = pus(output)
1350                self.assertEqual(input, unshuffle_output)
1351            else:
1352                self.assertRaises(RuntimeError, lambda: ps(input))
1353
1354        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
1355                                                    downscale_factor=None):
1356            downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
1357            channels = random.randint(1, 4)
1358            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
1359            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
1360            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
1361            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
1362
1363            if num_input_dims == 1:
1364                input = torch.rand(channels, requires_grad=True, device='mps')
1365            elif num_input_dims == 2:
1366                input = torch.rand(height, width, requires_grad=True, device='mps')
1367            else:
1368                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
1369                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')
1370
1371            pus = nn.PixelUnshuffle(downscale_factor)
1372            self.assertRaises(RuntimeError, lambda: pus(input))
1373
1374        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
1375            # For 1D - 2D, this is an error case.
1376            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
1377            is_contiguous_check = [True, False] if num_input_dims > 1 else [True]
1378            for is_contiguous in is_contiguous_check:
1379                _test_pixel_shuffle_unshuffle_helper(
1380                    num_input_dims=num_input_dims, is_contiguous=is_contiguous
1381                )
1382                _test_pixel_shuffle_unshuffle_helper(
1383                    num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
1384                )
1385                _test_pixel_shuffle_unshuffle_helper(
1386                    num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
1387                )
1388                _test_pixel_shuffle_unshuffle_helper(
1389                    num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
1390                )
1391
            # Error cases for pixel_unshuffle.
1393            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
1394            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
1395            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
1396            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
1397
1398        def test_pixel_shuffle_unshuffle_1D():
1399            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
1400
1401        def test_pixel_shuffle_unshuffle_2D():
1402            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
1403
1404        def test_pixel_shuffle_unshuffle_3D():
1405            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
1406
1407        def test_pixel_shuffle_unshuffle_4D():
1408            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
1409
1410        def test_pixel_shuffle_unshuffle_5D():
1411            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
1412
1413        test_pixel_shuffle_unshuffle_1D()
1414        test_pixel_shuffle_unshuffle_2D()
1415        test_pixel_shuffle_unshuffle_3D()
1416        test_pixel_shuffle_unshuffle_4D()
1417        test_pixel_shuffle_unshuffle_5D()
1418
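# A minimal reference sketch (a hypothetical helper added for illustration, not
# part of the original suite): the per-element index check in
# _verify_pixel_shuffle above corresponds to this reshape/permute decomposition
# of pixel_shuffle for a 4D NCHW tensor with upscale factor r.
def _pixel_shuffle_reference_sketch(x, r):
    n, c, h, w = x.shape
    # (N, C*r*r, H, W) -> (N, C, r, r, H, W) -> (N, C, H, r, W, r) -> (N, C, H*r, W*r)
    out = x.reshape(n, c // (r * r), r, r, h, w)
    out = out.permute(0, 1, 4, 2, 5, 3)
    return out.reshape(n, c // (r * r), h * r, w * r)
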
1419class MPSReluTest(TestCaseMPS):
1420    def _npRelu(self, np_features):
1421        return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)
1422
1423    def testNpRelu(self):
1424        torch.testing.assert_close(
1425            np.array([[0., 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]),
1426            self._npRelu(
1427                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
1428                                                         0.9]])))
1429
1430    def _testRelu(self, np_features, device):
1431        np_relu = self._npRelu(np_features)
1432        # Convert the numpy array to a PyTorch Tensor,
1433        # and move the Tensor to the CPU/GPU based on the "device" parameter
1434        py_tensor = torch.from_numpy(np_features).to(device)
1435        py_relu = torch.nn.ReLU(inplace=False)(py_tensor)
1436        py_relu_cpu = py_relu.to("cpu")
1437
1438        self.assertEqual(np_relu, py_relu_cpu)
1439
1440    def _testReluInPlace(self, np_features, device):
1441        np_relu = self._npRelu(np_features)
1442        # Convert the numpy array to a PyTorch Tensor,
1443        # and move the Tensor to the CPU/GPU based on the "device" parameter
1444        py_tensor = torch.from_numpy(np_features).to(device)
1445        py_relu = torch.nn.ReLU(inplace=True)(py_tensor)
1446        py_relu_cpu = py_relu.to("cpu")
1447
1448        self.assertEqual(np_relu, py_relu_cpu)
        # In-place ReLU modifies the input itself, which should therefore match the ReLU output
1450        self.assertEqual(np_relu, py_tensor.to("cpu"))
1451
1452    def testNumbersCPU(self):
1453        for t in [np.int32]:
1454            # Force execution on CPU even if a GPU kernel is available for the type.
1455            self._testRelu(
1456                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
1457                device="cpu")
1458            self._testReluInPlace(
1459                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
1460                device="cpu")
1461
1462    def testNumbersGPU(self):
1463        for t in [np.float16, np.float32]:
1464            self._testRelu(
1465                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
1466                device="mps")
1467            self._testReluInPlace(
1468                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
1469                device="mps")
1470            self._testRelu(np.array([]).astype(t), device="mps")
1471            self._testReluInPlace(np.array([]).astype(t), device="mps")
1472
1473class MatmulTest(TestCaseMPS):
1474    def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
1475        if expand_tensor_1_shape:
1476            tensor1_mps = torch.randn(shape_tensor_1, device="mps").expand(expand_tensor_1_shape)
1477        else:
1478            tensor1_mps = torch.randn(shape_tensor_1, device="mps")
1479
1480        if expand_tensor_2_shape:
1481            tensor2_mps = torch.randn(shape_tensor_2, device="mps").expand(expand_tensor_2_shape)
1482        else:
1483            tensor2_mps = torch.randn(shape_tensor_2, device="mps")
1484
1485        tensor1_cpu = tensor1_mps.to("cpu")
1486        tensor2_cpu = tensor2_mps.to("cpu")
1487
1488        matmul_cpu = torch.matmul(tensor1_cpu, tensor2_cpu)
1489        matmul_mps = torch.matmul(tensor1_mps, tensor2_mps)
1490
1491        self.assertEqual(matmul_cpu, matmul_mps.to("cpu"))
1492
1493    def test_vector_x_vector(self):
1494        # uses `dot`
1495        self._helper(3, 3)
1496
1497    def test_matrix_x_vector(self):
1498        # uses `addmv`
1499        self._helper((3, 4), 4)
1500
1501    def test_batched_matrix_x_broadcasted_vector(self):
1502        self._helper((10, 3, 4), 4)
1503
1504    def test_batched_matrix_x_batched_matrix(self):
1505        # uses `bmm.out`
1506        self._helper((10, 3, 4), (10, 4, 5))
1507
1508    def test_batched_matrix_x_broadcasted_matrix(self):
1509        self._helper((10, 3, 4), (4, 5))
1510
1511
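# Shape sketch (a hypothetical helper added for illustration): torch.matmul
# broadcasts batch dimensions and treats a trailing 1-D operand as a vector,
# which is what the batched/broadcasted cases in MatmulTest exercise.
def _matmul_broadcast_sketch():
    a = torch.randn(10, 3, 4)   # batched matrix
    v = torch.randn(4)          # vector, broadcast over the batch
    b = torch.randn(4, 5)       # matrix, broadcast over the batch
    assert torch.matmul(a, v).shape == (10, 3)
    assert torch.matmul(a, b).shape == (10, 3, 5)
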
1512class MPSLeakyReluTest(TestCaseMPS):
1513    def _npLeakyRelu(self, np_features, negative_slope=0.1):
1514        return np.maximum(np_features, negative_slope * np_features).astype(np_features.dtype)
1515
1516    def testNpLeakyRelu(self):
1517        torch.testing.assert_close(
1518            np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
1519                      [0.1, -0.03, 0.5, -0.07, 0.9]]),
1520            self._npLeakyRelu(
1521                np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
1522                                                         0.9]]),
1523                negative_slope=0.1))
1524
1525    def _testLeakyRelu(self, shape, dtype, negative_slope, contiguous):
1526        cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
1527        mps_x = cpu_x.detach().clone().to('mps')
1528
1529        if not contiguous and not (0 in shape or len(shape) < 2):
            # Transposing will make the tensor non-contiguous
1531            cpu_x = cpu_x.transpose(0, 1)
1532            mps_x = mps_x.transpose(0, 1)
1533            assert not mps_x.is_contiguous()
1534
1535        cpu_x.requires_grad_()
1536        mps_x.requires_grad_()
1537
1538        relu_op = torch.nn.LeakyReLU(negative_slope)
1539
1540        cpu_leaky_relu = relu_op(cpu_x)
1541        mps_leaky_relu = relu_op(mps_x)
1542        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))
1543
1544        # test backward pass
1545
1546        cpu_grad = torch.ones_like(cpu_leaky_relu)
1547        mps_grad = cpu_grad.to('mps')
1548
1549        mps_leaky_relu.backward(gradient=mps_grad)
1550        cpu_leaky_relu.backward(gradient=cpu_grad)
1551
1552        assert cpu_x.grad is not None  # Check that the grad is well-populated
1553        self.assertEqual(cpu_x.grad, mps_x.grad)
1554
1555    def testNumbersCPU(self):
1556        for t in [torch.float, torch.half]:
1557            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
1558                for contiguous in [True, False]:
1559                    self._testLeakyRelu(shape,
1560                                        dtype=t,
1561                                        negative_slope=0.2,
1562                                        contiguous=contiguous)
1563
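# Gradient sketch (a hypothetical helper added for illustration): LeakyReLU
# backward scales the incoming gradient by 1 where x > 0 and by negative_slope
# elsewhere, which is the closed form the backward comparison in
# _testLeakyRelu implicitly checks against the CPU implementation.
def _leaky_relu_grad_reference_sketch(x, negative_slope):
    return torch.where(x > 0, torch.ones_like(x), torch.full_like(x, negative_slope))
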
1564class TestAvgPool(TestCaseMPS):
1565    def _sum_pool2d(self, x, kernel_size):
1566        windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
1567        return torch.sum(windows, dim=1)
1568
1569    def _sum_pool3d(self, x, kernel_size):
        # unfold does not support a 3D sliding window, so split the tensor along the first kernel dim and sum each slice
1571        h = kernel_size[0]
        split_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # _sum_pool2d expects a (1, 1, n, m) view, so unsqueeze twice
        split_x = [self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:]) for t in split_x]
        joined_x = torch.cat(split_x)
1576        return joined_x.view(1, joined_x.numel())
1577
1578    def _avg_pool2d(self, x, kernel_size):
        size = math.prod(kernel_size)
1580        return self._sum_pool2d(x, kernel_size) / size
1581
1582    def _avg_pool3d(self, x, kernel_size):
        size = math.prod(kernel_size)
1584        return self._sum_pool3d(x, kernel_size) / size
1585
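    # Minimal usage sketch (a hypothetical method added for illustration): for a
    # 4D input, averaging the unfold windows directly should agree with
    # F.avg_pool2d, which is the identity the helpers above rely on.
    def _avg_pool2d_unfold_sketch(self):
        x = torch.arange(16.0).reshape(1, 1, 4, 4)
        expected = F.avg_pool2d(x, kernel_size=2)
        windows = F.unfold(x, kernel_size=2, stride=2)  # (N, C*k*k, L)
        actual = windows.mean(dim=1).reshape_as(expected)
        torch.testing.assert_close(actual, expected)
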
1586    def test_avg_pool2d_with_zero_divisor(self):
1587        self.assertRaisesRegex(RuntimeError, "divisor must be not zero",
1588                               lambda: F.avg_pool2d(torch.zeros(3, 3, 3), (2, 2), divisor_override=0))
1589
1590    def test_doubletensor_avg_pool2d_with_divisor(self):
1591        n, m = 3, 3
1592        input = torch.rand(1, 1, n, m)
1593        for i in range(1, n + 1):
1594            for j in range(1, m + 1):
1595                for divisor in [1, 7, i * j]:
1596                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
1597                    actual = actual.view(1, actual.numel())
1598                    expected = self._sum_pool2d(input, (i, j)) / divisor
1599                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)
1600
1601    def test_avg_pool2d_ceil_mode(self):
1602        # Regression test for gh-36977
1603        x = 10 * torch.randn((1, 16, 4, 4))
1604        y = torch.nn.functional.avg_pool2d(
1605            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
1606            padding=(0, 1), stride=2)
1607        self.assertFalse(torch.isnan(y).any())
1608        y = torch.nn.functional.avg_pool2d(
1609            x.to('mps'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
1610            padding=(0, 1), stride=2)
1611        self.assertFalse(torch.isnan(y).any())
1612
1613
1614class TestMPS(TestCaseMPS):
1615    def test_exp(self, device="mps", dtype=torch.float):
1616        for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
1617            b = torch.arange(18, dtype=dtype, device=device) / 3 * math.pi
1618            a = torch.tensor(v, dtype=dtype, device="mps") * b
1619            self.compare_with_numpy(torch.exp, np.exp, a)
1620
1621    def test_conv_raises_error(self, device='mps', dtype=torch.float):
1622        conv = nn.Conv1d(1, 65537, 3, padding=1).to('mps')
1623
1624        x = torch.ones([1, 1, 3])
1625        with self.assertRaises(NotImplementedError):
1626            y = conv(x.to("mps"))
1627
1628    def test_triu_inf(self, device="mps", dtype=torch.float):
1629        for diag in [-1, 0, 1]:
1630            mask = torch.full((3, 6, 6), float("-inf"))
            mask_mps = mask.detach().clone().to('mps')
1632            cpu_ref = torch.triu(mask, diagonal=diag)
1633            mps_out = torch.triu(mask_mps, diagonal=diag)
1634            self.assertEqual(cpu_ref, mps_out)
1635
1636    def test_exp1(self, device="mps", dtype=torch.float):
1637        input = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype)
1638        output = torch.exp(input)
1639        output_cpu = torch.exp(input.cpu())
        # If the exponentWithTensor: MPS call is used on an M1 running macOS 14.5, the test fails with:
        # Mismatched elements: 3 / 4 (75.0%)
        # Greatest absolute difference: 1.1920928955078125e-07 at index (3,) (up to 1e-08 allowed)
        # Greatest relative difference: 1.0786502002702036e-07 at index (3,) (up to 1e-08 allowed)
1644        self.assertEqual(output, output_cpu, atol=1e-8, rtol=1e-8)
1645
1646    def test_exp_strided_output(self):
1647        x = torch.rand((256, 10), device='mps')
1648        x_cpu = x.to("cpu")
1649
1650        x = x.permute(1, 0)
1651        x_cpu = x_cpu.permute(1, 0)
1652
1653        res = x.exp()
1654        res_cpu = x_cpu.exp()
1655        self.assertEqual(res, res_cpu)
1656
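    # Layout sketch (a hypothetical method added for illustration): permute
    # returns a non-contiguous view, so test_exp_strided_output above exercises
    # the MPS exp kernel on transposed memory.
    def _strided_view_sketch(self):
        x = torch.rand(256, 10)
        self.assertTrue(x.is_contiguous())
        self.assertFalse(x.permute(1, 0).is_contiguous())
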
1657    def _testLeakyRelu(self, np_features, negative_slope, device):
1658        cpu_x = torch.from_numpy(np_features).requires_grad_()
1659        mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()
1660        relu_op = torch.nn.LeakyReLU(negative_slope)
1661
1662        cpu_leaky_relu = relu_op(cpu_x)
1663        mps_leaky_relu = relu_op(mps_x)
1664        torch.testing.assert_close(cpu_leaky_relu, mps_leaky_relu.to('cpu'))
1665
1666        # test backward pass
1667        cpu_grad = torch.ones_like(cpu_leaky_relu)
1668        mps_grad = cpu_grad.to('mps')
1669        cpu_leaky_relu.backward(gradient=cpu_grad)
1670        mps_leaky_relu.backward(gradient=mps_grad)
1671        torch.testing.assert_close(cpu_x.grad, mps_x.grad.to('cpu'))
1672
1673    def testNumbersGPU(self):
1674        for t in [np.float32]:
1675            self._testLeakyRelu(
1676                np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
1677                negative_slope=0.1,
1678                device="mps")
1679
1680    def test_fill(self):
1681
1682        def helper(val, shape, dtype):
1683            tensor = torch.zeros(shape, device='mps', dtype=dtype)
1684            tensor_mps = tensor.fill_(val)
1685
1686            tensor_0 = torch.zeros(shape, device='cpu', dtype=dtype)
1687            tensor_cpu = tensor_0.fill_(val)
1688
1689            self.assertEqual(tensor_mps, tensor_cpu)
1690
1691        helper(0, [1024], torch.float32)
1692        helper(0.2, [2, 3], torch.float32)
1693        helper(0.2 + 0.5j, [2, 3], torch.complex64)
1694
1695    def test_fill_storage_offset(self):
1696        shape = [2, 10]
1697        val = 0.2
1698        tensor = torch.ones(shape, device="mps")
1699        tensor_mps = tensor[:][1].fill_(val)
1700        tensor_0 = torch.ones(shape, device="cpu")
1701        tensor_cpu = tensor_0[:][1].fill_(val)
1702
1703        self.assertEqual(tensor_mps, tensor_cpu)
1704        self.assertEqual(tensor, tensor_0)
1705
1706        shape = [1, 10]
1707        val = 0.0
1708        tensor = torch.ones(shape, device="mps")
1709        val_tensor_mps = torch.tensor(val, device="mps")
1710        tensor_mps = tensor[:, 9].fill_(val_tensor_mps)
1711        # Regression test for https://github.com/pytorch/pytorch/issues/114692
1712        tensor[:, 5].fill_(val_tensor_mps)
1713        tensor_0 = torch.ones(shape, device="cpu")
1714        val_tensor_cpu = torch.tensor(val, device="cpu")
1715        tensor_cpu = tensor_0[:, 9].fill_(val_tensor_cpu)
1716        tensor_0[:, 5].fill_(val_tensor_cpu)
1717
1718        self.assertEqual(tensor_mps.to(device="cpu"), tensor_cpu)
1719        self.assertEqual(tensor.to(device="cpu"), tensor_0)
1720
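    # Storage-offset sketch (a hypothetical method added for illustration):
    # slicing produces a view with a nonzero storage_offset, which is the case
    # fill_ must handle in test_fill_storage_offset above.
    def _storage_offset_sketch(self):
        base = torch.zeros(2, 10, device="mps")
        view = base[1]
        self.assertEqual(view.storage_offset(), 10)
        view.fill_(1.0)
        self.assertEqual(base[1].sum().item(), 10.0)
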
1721    def test_cdist_large(self, device="mps"):
1722        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1723            x = torch.randn(100, 10, device=device)
1724            y = torch.randn(100, 10, device=device)
1725            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1726            expected = self._brute_cdist(x, y, p=2)
1727            self.assertEqual(expected, actual)
1728
1729    def test_cdist_large_batch(self, device="mps"):
1730        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1731            x = torch.randn(4, 3, 100, 10, device=device)
1732            y = torch.randn(4, 3, 100, 10, device=device)
1733            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1734            expected = self._brute_cdist(x, y, p=2)
1735            self.assertEqual(expected, actual)
1736
1737    def test_cdist_non_contiguous(self, device="mps"):
1738        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1739            x = torch.randn(5, 7, device=device).mT
1740            y = torch.randn(5, 3, device=device).mT
1741            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1742            expected = self._brute_cdist(x, y, p=2)
1743            self.assertFalse(x.is_contiguous())
1744            self.assertFalse(y.is_contiguous())
1745            self.assertEqual(expected, actual)
1746
1747            x = torch.randn(7, 5, device=device)
1748            y = torch.randn(5, 3, device=device).t()
1749            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1750            expected = self._brute_cdist(x, y, p=2)
1751            self.assertTrue(x.is_contiguous())
1752            self.assertFalse(y.is_contiguous())
1753            self.assertEqual(expected, actual)
1754
1755            x = torch.randn(5, 7, device=device).t()
1756            y = torch.randn(3, 5, device=device)
1757            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1758            expected = self._brute_cdist(x, y, p=2)
1759            self.assertFalse(x.is_contiguous())
1760            self.assertTrue(y.is_contiguous())
1761            self.assertEqual(expected, actual)
1762
1763    def test_cdist_non_contiguous_batch(self, device="mps"):
1764        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1765            x = torch.randn(4, 3, 2, 5, 7, device=device).mT
1766            y = torch.randn(4, 3, 2, 5, 3, device=device).mT
1767            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1768            expected = self._brute_cdist(x, y, p=2)
1769            self.assertFalse(x.is_contiguous())
1770            self.assertFalse(y.is_contiguous())
1771            self.assertEqual(expected, actual)
1772
1773            x = torch.randn(7, 2, 7, 5, device=device)
1774            y = torch.randn(7, 2, 5, 3, device=device).mT
1775            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1776            expected = self._brute_cdist(x, y, p=2)
1777            self.assertTrue(x.is_contiguous())
1778            self.assertFalse(y.is_contiguous())
1779            self.assertEqual(expected, actual)
1780
1781            x = torch.randn(4, 5, 7, device=device).mT
1782            y = torch.randn(4, 3, 5, device=device)
1783            actual = torch.cdist(x, y, p=2, compute_mode=cm)
1784            expected = self._brute_cdist(x, y, p=2)
1785            self.assertFalse(x.is_contiguous())
1786            self.assertTrue(y.is_contiguous())
1787            self.assertEqual(expected, actual)
1788
1789    def test_cdist_euclidean_large(self, device="mps"):
1790        def _test_euclidean_large_cdist(sizex, sizey=None):
1791            if sizey is None:
1792                sizey = sizex
1793            x = torch.randn(sizex, device=device, dtype=torch.float)
1794            y = torch.randn(sizey, device=device, dtype=torch.float)
1795            eps = 1e-6
            # shift x away from y to avoid the extremum at zero distance
1797            x = x - (((x - y) < eps).float() * 2 * eps)
1798            x.requires_grad = True
1799            y.requires_grad = True
1800            dist = torch.cdist(x, y, p=2)
1801            # Do a backward pass to check that it is valid for large
1802            # matrices
1803            loss = dist.sum()
1804            loss.backward()
1805
1806        _test_euclidean_large_cdist((2000, 5))
1807
1808    def test_cdist_same_inputs(self, device="mps"):
1809        # Test to detect issues in cdist gradient calculation
1810        # When the distances are 0
1811        sizex = (1, 27, 32)
1812        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
1813            x = torch.randn(sizex, device=device, dtype=torch.float)
1814            dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
1815            y = x.clone()
1816            eps = 1e-6
1817            x.requires_grad = True
1818            d = torch.cdist(x, y)
1819            d.backward(dist_grad)
            # Check that the backward pass does not produce invalid
            # values such as NaN or Inf
1822            assert torch.isfinite(x.grad).all()
1823
1824
1825    def _brute_cdist(self, x, y, p=2):
1826        r1 = x.shape[-2]
1827        r2 = y.shape[-2]
1828        if r1 == 0 or r2 == 0:
1829            return torch.empty(r1, r2, device=x.device)
1830        return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
1831
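    # Shape sketch (a hypothetical method added for illustration): in
    # _brute_cdist, x[..., None, :] is (..., r1, 1, m) and y[..., None, :, :] is
    # (..., 1, r2, m); their difference broadcasts to (..., r1, r2, m) and the
    # norm over dim=-1 yields the (..., r1, r2) matrix that torch.cdist returns.
    def _brute_cdist_shape_sketch(self):
        x = torch.randn(3, 5)
        y = torch.randn(4, 5)
        self.assertEqual(self._brute_cdist(x, y).shape, torch.Size([3, 4]))
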
1832    def test_cdist_norm(self, device="mps"):
1833        for r1 in [3, 4]:
1834            for m in [2, 3]:
1835                for r2 in [4, 6]:
1836                    for p in [0, 1, 1.5, 2.5, float('inf')]:
1837                        x = torch.randn(r1, m, device=device)
1838                        y = torch.randn(r2, m, device=device)
1839                        if p == 2:
1840                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1841                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
1842                                expected = self._brute_cdist(x, y, p=2)
1843                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
1844                        else:
1845                            actual = torch.cdist(x, y, p=p)
1846                            expected = self._brute_cdist(x, y, p=p)
1847                            self.assertEqual(expected, actual)
1848
1849    def test_cdist_norm_batch(self, device="mps"):
1850        for r1 in [3, 4]:
1851            for m in [2, 3]:
1852                for r2 in [4, 6]:
1853                    for p in [0, 3, 1.5, 2.5, float('inf')]:
1854                        x = torch.randn(2, 3, 6, r1, m, device=device)
1855                        y = torch.randn(2, 3, 6, r2, m, device=device)
1856                        if p == 2:
1857                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
1858                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
1859                                expected = self._brute_cdist(x, y, p=2)
1860                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
1861                        else:
1862                            actual = torch.cdist(x, y, p=p)
1863                            expected = self._brute_cdist(x, y, p=p)
1864                            self.assertEqual(expected, actual)
1865
1866    def test_mm(self):
1867        B = torch.ones(5, 6).to("mps")
1868        C = torch.ones(6, 5).to("mps")
1869        D = torch.mm(B, C).cpu()
1870        torch.testing.assert_close(D, torch.full((5, 5), 6.0))
1871
1872    def test_linalg_cross(self):
1873        def helper(dtype):
1874            device = "mps"
1875            if dtype is torch.int32 or dtype is torch.int64:
1876                x = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
1877                y = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
1878            else:
1879                x = torch.rand(100, 3, 100, dtype=dtype, device=device)
1880                y = torch.rand(100, 3, 100, dtype=dtype, device=device)
1881            x_cpu = x.to("cpu")
1882            y_cpu = y.to("cpu")
1883            res1 = torch.linalg.cross(x, y, dim=1)
1884            res2 = torch.tensor((), dtype=dtype, device=device)
1885            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
1886            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
1887            torch.linalg.cross(x, y, dim=1, out=res2)
1888            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
1889            self.assertEqual(res1, res2)
1890            self.assertEqual(res1, res1_cpu)
1891            self.assertEqual(res2, res2_cpu)
1892
1893            # test for broadcastable inputs
1894            if dtype is torch.int32 or dtype is torch.int64:
1895                x = torch.randint(0, 99999, (1, 3, 2), dtype=dtype, device=device)
1896                y = torch.randint(0, 99999, (4, 3, 1), dtype=dtype, device=device)
1897            else:
1898                x = torch.rand(1, 3, 2, dtype=dtype, device=device)
1899                y = torch.rand(4, 3, 1, dtype=dtype, device=device)
1900            x_cpu = x.to("cpu")
1901            y_cpu = y.to("cpu")
1902            res1 = torch.linalg.cross(x, y, dim=1)
1903            res2 = torch.tensor((), dtype=dtype, device=device)
1904            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
1905            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
1906            torch.linalg.cross(x, y, dim=1, out=res2)
1907            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
1908            self.assertEqual(res1, res2)
1909            self.assertEqual(res1, res1_cpu)
1910            self.assertEqual(res2, res2_cpu)
        for dtype in [torch.int32, torch.int64, torch.float32]:
            helper(dtype)
1912
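    # Formula sketch (a hypothetical method added for illustration):
    # linalg.cross computes (x cross y)_0 = x1*y2 - x2*y1 (and cyclic
    # permutations) along the chosen dim, broadcasting the remaining dims,
    # which is why the (1, 3, 2) x (4, 3, 1) case above is valid.
    def _cross_formula_sketch(self):
        e1 = torch.tensor([1.0, 0.0, 0.0])
        e2 = torch.tensor([0.0, 1.0, 0.0])
        torch.testing.assert_close(torch.linalg.cross(e1, e2), torch.tensor([0.0, 0.0, 1.0]))
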
1913    def test_cross(self):
1914        a = torch.randn(4, 3, device="mps")
1915        b = torch.randn(4, 3, device="mps")
1916        a_cpu = a.to("cpu")
1917        b_cpu = b.to("cpu")
1918        res = torch.cross(a, b, dim=1)
1919        res_cpu = torch.cross(a_cpu, b_cpu, dim=1)
1920        self.assertEqual(res, res_cpu)
1921
1922    def test_addmm(self):
1923        A = torch.ones(5, 5).to("mps")
1924        B = torch.ones(5, 6).to("mps")
1925        C = torch.ones(6, 5).to("mps")
1926        D = torch.addmm(A, B, C).to("cpu")
1927        torch.testing.assert_close(D, torch.full((5, 5), 7.0))
1928
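    # Formula sketch (a hypothetical method added for illustration):
    # torch.addmm(A, B, C) computes A + B @ C with the default alpha=beta=1,
    # hence the expected 1 + 6 = 7 fill in test_addmm above.
    def _addmm_formula_sketch(self):
        A = torch.ones(5, 5)
        B = torch.ones(5, 6)
        C = torch.ones(6, 5)
        torch.testing.assert_close(torch.addmm(A, B, C), A + B @ C)
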
1929    def test_bmm(self):
1930        batch1_cpu = torch.randn(10, 3, 4)
1931        batch2_cpu = torch.randn(10, 4, 5)
1932
1933        batch1_mps = batch1_cpu.detach().clone().to("mps")
1934        batch2_mps = batch2_cpu.detach().clone().to("mps")
1935
1936        output_cpu = torch.bmm(batch1_cpu, batch2_cpu)
1937        output_mps = torch.bmm(batch1_mps, batch2_mps)
1938
1939        self.assertEqual(output_cpu, output_mps)
1940        self.assertEqual(output_cpu.size(), output_mps.size())
1941
1942    @xfailIf(product_version < 15.0)
1943    @parametrize("dtype", [torch.float16, torch.bfloat16])
1944    def test_large_bmm(self, dtype):
1945        batch1 = torch.randn(11, 20064, 128, dtype=dtype, device='mps')
1946        batch2 = torch.randn(11, 128, 20064, dtype=dtype, device='mps')
1947        output_cpu = torch.bmm(batch1.cpu(), batch2.cpu())
1948        output_mps = torch.bmm(batch1, batch2)
1949
        # Use a looser tolerance for the FP16 comparison
1951        tol = 1e-2 if dtype == torch.float16 else None
1952        self.assertEqual(output_cpu, output_mps, atol=tol, rtol=tol)
1953        self.assertEqual(output_cpu.size(), output_mps.size())
1954
1955
1956    def test_addr(self):
1957        A = torch.ones(5, 10).to("mps")
1958        B = torch.ones(5).to("mps")
1959        C = torch.ones(10).to("mps")
1960        D = torch.addr(A, B, C).to("cpu")
1961        torch.testing.assert_close(D, torch.full((5, 10), 2.0))
1962
1963    def test_trace(self):
1964        M_cpu = torch.randn(3, 3)
1965        M_mps = M_cpu.detach().clone().to("mps")
1966
1967        output_cpu = torch.trace(M_cpu)
1968        output_mps = torch.trace(M_mps)
1969
1970        self.assertEqual(output_cpu, output_mps)
1971        self.assertEqual(output_cpu.size(), output_mps.size())
1972
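    # Identity sketch (a hypothetical method added for illustration):
    # torch.trace(M) is the sum of the main diagonal, i.e. M.diagonal().sum().
    def _trace_identity_sketch(self):
        M = torch.randn(3, 3)
        torch.testing.assert_close(torch.trace(M), M.diagonal().sum())
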
1973    def test_addbmm(self):
1974        M_cpu = torch.randn(3, 5)
1975        batch1_cpu = torch.randn(10, 3, 4)
1976        batch2_cpu = torch.randn(10, 4, 5)
1977
1978        M_mps = M_cpu.detach().clone().to("mps")
1979        batch1_mps = batch1_cpu.detach().clone().to("mps")
1980        batch2_mps = batch2_cpu.detach().clone().to("mps")
1981
1982        output_cpu = torch.addbmm(M_cpu, batch1_cpu, batch2_cpu)
1983        output_mps = torch.addbmm(M_mps, batch1_mps, batch2_mps)
1984
1985        self.assertEqual(output_cpu, output_mps)
1986        self.assertEqual(output_cpu.size(), output_mps.size())
1987
1988    def test_baddbmm(self):
1989        def helper(input_shape, batch1_shape, batch2_shape):
1990            M_cpu = torch.randn(input_shape)
1991            batch1_cpu = torch.randn(batch1_shape)
1992            batch2_cpu = torch.randn(batch2_shape)
1993            alpha = 1.2
1994            beta = 0.8
1995
1996            M_mps = M_cpu.detach().clone().to("mps")
1997            batch1_mps = batch1_cpu.detach().clone().to("mps")
1998            batch2_mps = batch2_cpu.detach().clone().to("mps")
1999
2000            output_cpu = torch.baddbmm(M_cpu, batch1_cpu, batch2_cpu, beta=beta, alpha=alpha)
2001            output_mps = torch.baddbmm(M_mps, batch1_mps, batch2_mps, beta=beta, alpha=alpha)
2002
2003            self.assertEqual(output_cpu, output_mps)
2004            self.assertEqual(output_cpu.size(), output_mps.size())
2005
2006        helper(input_shape=(3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
2007        helper(input_shape=(10, 3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
2008        helper(input_shape=(1, 77, 77), batch1_shape=(8, 77, 64), batch2_shape=(8, 64, 77))
2009
2010    def test_local_scalar_dense_mps(self):
2011        x_cpu = torch.randn(1)
2012        y_mps = x_cpu.to("mps")
2013        torch.testing.assert_close(x_cpu.item(), y_mps.item())
2014
2015    def test_linear_1d_weight(self):
2016        device = 'cpu'
2017        projected = torch.rand([8]).to(device)
2018        x = torch.rand([1, 2, 2, 8]).to(device)
2019        x_mps = x.to('mps')
2020        projected_mps = projected.to('mps')
2021        linear = F.linear(x, projected)
2022        linear_mps = F.linear(x_mps, projected_mps)
2023
2024        self.assertEqual(linear, linear_mps)
2025
2026        projected = torch.rand([1, 8]).to(device)
2027        x = torch.rand([1, 2, 2, 8]).to(device)
2028        x_mps = x.to('mps')
2029        projected_mps = projected.to('mps')
2030        linear = F.linear(x, projected)
2031        linear_mps = F.linear(x_mps, projected_mps)
2032
2033        self.assertEqual(linear, linear_mps)
2034
2035    def test_linear_bias(self):
2036        def helper(bias_shape):
2037            device = "cpu"
2038            x = torch.randn(2, 2, 2, 64, device=device)
2039            linear = torch.nn.Linear(64, 4, device=device)
2040            linear.bias = torch.nn.Parameter(torch.randn(bias_shape, dtype=torch.float32, device=device))
2041            y = linear(x)
2042            device = "mps"
2043            x_mps = x.to(device)
2044            linear.to(device)
2045            y_mps = linear(x_mps)
2046            self.assertEqual(y, y_mps)
2047
2048        helper(())
2049        helper((2, 4))
2050
2051    def test_linear_errors(self):
        # Unsupported dtypes and tensors on mismatched devices should raise
2053        size = (3, 3)
2054
2055        # Unsupported dtypes
2056        with self.assertRaisesRegex(RuntimeError, "does not support linear for non-float weights"):
2057            torch.nn.functional.linear(torch.rand(size, device='mps'),
2058                                       torch.randint(-10, 10, size, dtype=torch.int8, device='mps'))
2059
        # Weights on the wrong device
2061        with self.assertRaisesRegex(RuntimeError, "argument weight is on cpu but expected on mps"):
2062            torch.nn.functional.linear(torch.rand(size, device='mps'),
2063                                       torch.rand(size, device='cpu'))
2064
        # Input on the wrong device
2066        with self.assertRaisesRegex(RuntimeError, "argument input is on cpu but expected on mps"):
2067            torch.nn.functional.linear(torch.rand(size, device='cpu'),
2068                                       torch.rand(size, device='mps'))
2069
2070    def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False):
2071        cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias)
2072        mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias)
2073
2074        # Use the same weights and bias as the ones from the cpu
2075        mps_linear.weight.data = cpu_linear.weight.data.detach().clone().to("mps")
2076
2077        if bias:
2078            mps_linear.bias.data = cpu_linear.bias.data.detach().clone().to("mps")
2079
2080        linear_mps_input = torch.randn(shape).to('mps')
2081        linear_cpu_input = linear_mps_input.detach().clone().to('cpu')
2082
2083        if backward_pass:
2084            linear_mps_input = linear_mps_input.requires_grad_()
2085            linear_cpu_input = linear_cpu_input.requires_grad_()
2086
2087        linear_cpu_output = cpu_linear(linear_cpu_input)
2088        linear_mps_output = mps_linear(linear_mps_input)
2089
2090        self.assertEqual(linear_cpu_output, linear_mps_output.to('cpu'))
2091        self.assertEqual(linear_cpu_output.size(), linear_mps_output.size())
2092
2093        if backward_pass:
2094            cpu_grad = torch.rand_like(linear_cpu_output, requires_grad=True)
2095            grad = cpu_grad.detach().to('mps').requires_grad_()
2096
2097            linear_cpu_output.backward(gradient=cpu_grad, create_graph=True)
2098            linear_mps_output.backward(gradient=grad, create_graph=True)
2099
2100            self.assertEqual(linear_cpu_input.grad.size(), linear_mps_input.grad.size())
2101            self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
2102
2103            self.assertEqual(cpu_linear.weight.grad.size(), mps_linear.weight.grad.size())
2104            self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
2105            if bias:
2106                self.assertEqual(cpu_linear.bias.grad.size(), mps_linear.bias.grad.size())
2107                self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad.to("cpu"), atol=8e-04, rtol=10.4e-05)
2108
2109            # test gradgrad
2110            x_grad_out = torch.rand_like(linear_cpu_input)
2111            x_grad_out_mps = x_grad_out.to("mps")
2112            w_grad_out = torch.rand_like(cpu_linear.weight)
2113            w_grad_out_mps = w_grad_out.to("mps")
2114
2115            linear_cpu_input.grad.detach().zero_()
2116            linear_mps_input.grad.detach().zero_()
2117            cpu_linear.weight.grad.detach().zero_()
2118            mps_linear.weight.grad.detach().zero_()
2119            if bias:
2120                b_grad_out = torch.rand_like(cpu_linear.bias)
2121                b_grad_out_mps = b_grad_out.to("mps")
2122                cpu_linear.bias.grad.detach().zero_()
2123                mps_linear.bias.grad.detach().zero_()
2124
2125            linear_cpu_input.grad.backward(x_grad_out, retain_graph=True)
2126            linear_mps_input.grad.backward(x_grad_out_mps, retain_graph=True)
2127            cpu_linear.weight.grad.backward(w_grad_out, retain_graph=True)
2128            mps_linear.weight.grad.backward(w_grad_out_mps, retain_graph=True)
2129            if bias:
2130                cpu_linear.bias.grad.backward(b_grad_out, retain_graph=True)
2131                mps_linear.bias.grad.backward(b_grad_out_mps, retain_graph=True)
2132
2133            self.assertEqual(cpu_grad.grad, grad.grad)
2134            self.assertEqual(linear_cpu_input.grad, linear_mps_input.grad)
2135            self.assertEqual(cpu_linear.weight.grad, mps_linear.weight.grad)
2136            if bias:
2137                self.assertEqual(cpu_linear.bias.grad, mps_linear.bias.grad)
2138
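    # Double-backward sketch (a hypothetical method added for illustration):
    # running the first backward with create_graph=True makes .grad itself
    # differentiable, which is what the gradgrad section of _linear_helper
    # above relies on when it calls .backward() on the gradients.
    def _gradgrad_sketch(self):
        x = torch.tensor(2.0, requires_grad=True)
        (x ** 3).backward(create_graph=True)
        g = x.grad                # 3 * x**2 = 12, still part of the graph
        g.backward()              # d(g)/dx = 6 * x = 12 accumulates into x.grad
        torch.testing.assert_close(x.grad, torch.tensor(24.0))
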
    def test_linear1D(self):
        self._linear_helper(in_features=2, out_features=3, shape=[2], bias=True, backward_pass=False)

    def test_linear1D_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=[2], bias=True, backward_pass=True)

    def test_linear2D(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=True, backward_pass=False)

    def test_linear2D_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=True, backward_pass=True)

    def test_linear2D_no_bias(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=False, backward_pass=False)

    def test_linear2D_no_bias_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 2), bias=False, backward_pass=True)

    def test_linear3D(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=True, backward_pass=False)

    def test_linear3D_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=True, backward_pass=True)

    def test_linear3D_no_bias(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=False, backward_pass=False)

    def test_linear3D_no_bias_backward(self):
        self._linear_helper(in_features=2, out_features=3, shape=(4, 5, 2), bias=False, backward_pass=True)
2168
2169    def test_uniform(self):
2170        low = torch.zeros(5, 5, requires_grad=True)
2171        high = (torch.ones(5, 5) * 3).requires_grad_()
2172        low_1d = torch.zeros(1, requires_grad=True)
2173        high_1d = (torch.ones(1) * 3).requires_grad_()
2174        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
2175        self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
2176        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
2177        self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
2178        self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))
2179
2180        # Check log_prob computation when value outside range
2181        uniform = Uniform(low_1d, high_1d, validate_args=False)
2182        above_high = torch.tensor([4.0])
2183        below_low = torch.tensor([-1.0])
2184        self.assertEqual(uniform.log_prob(above_high).item(), -inf)
2185        self.assertEqual(uniform.log_prob(below_low).item(), -inf)
2186
2187        # check cdf computation when value outside range
2188        self.assertEqual(uniform.cdf(below_low).item(), 0)
2189        self.assertEqual(uniform.cdf(above_high).item(), 1)
2190
2191        state = torch.get_rng_state()
2192        rand = low.new(low.size()).uniform_()
2193        torch.set_rng_state(state)
2194        u = Uniform(low, high).rsample()
2195        u.backward(torch.ones_like(u))
2196        self.assertEqual(low.grad, 1 - rand)
2197        self.assertEqual(high.grad, rand)
2198        low.grad.zero_()
2199        high.grad.zero_()
2200
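    # Reparameterization sketch (a hypothetical method added for illustration):
    # Uniform(low, high).rsample() is low + (high - low) * u with u ~ U(0, 1),
    # so d(sample)/d(low) = 1 - u and d(sample)/d(high) = u; the two gradients
    # therefore sum to one, matching the checks in test_uniform above.
    def _uniform_rsample_grad_sketch(self):
        low = torch.zeros(3, requires_grad=True)
        high = torch.ones(3, requires_grad=True)
        u = Uniform(low, high).rsample()
        u.backward(torch.ones_like(u))
        torch.testing.assert_close(low.grad + high.grad, torch.ones(3))
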
2201    def test_randperm(self, device="mps"):
2202        rng_device = None
2203        for n in (5, 100, 50000, 100000):
2204            for dtype in (torch.long, torch.half, torch.float):
2205                if n > 2049 and dtype == torch.half:  # Large n for torch.half will raise an exception, do not test here.
2206                    continue
2207                if n > 256 and dtype == torch.bfloat16:
2208                    continue
2209                with torch.random.fork_rng(devices=rng_device):
2210                    res1 = torch.randperm(n, dtype=dtype, device=device)
2211                res2 = torch.empty(0, dtype=dtype, device=device)
2212                torch.randperm(n, out=res2, dtype=dtype, device=device)
2213                self.assertEqual(res1.cpu().sort().values.long(), torch.arange(n, device=device))
2214
2215        # Default type is long
2216        for n in (100, 10000):
2217            self.assertEqual(torch.randperm(n, device=device).dtype, torch.long)
2218
2219        # randperm of 0 elements is an empty tensor
2220        res1 = torch.randperm(0)
2221        res2 = torch.tensor(5, dtype=dtype, device=device)
2222        torch.randperm(0, out=res2)
2223        self.assertEqual(res1.numel(), 0)
2224        self.assertEqual(res2.numel(), 0)
2225
2226        # Test non-contiguous tensors
2227        for n in (4, 5, 6, 10, 20):
2228            non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t()
2229            self.assertFalse(non_contiguous_tensor.is_contiguous())
2230            with torch.random.fork_rng(devices=rng_device):
2231                res = torch.randperm(n, dtype=torch.long, device=device)
2232            torch.randperm(n, out=non_contiguous_tensor)
2233            self.assertEqual(res.cpu().sort().values.long(), torch.arange(n, device=device))
2234
2235    # Test forward maxpool2d
2236    def test_max_pool2d(self):
2237        def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):
2238
2239            cpu_x = None
2240            if (test_ties):
2241                cpu_x = torch.ones(shape, device='cpu', dtype=torch.float, requires_grad=True)
2242            else:
2243                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
2244            x = cpu_x.detach().clone().to('mps').requires_grad_()
2245
2246            pool = torch.nn.MaxPool2d(kernel_size=ks, padding=padding, dilation=dilation,
2247                                      ceil_mode=ceil_mode, return_indices=return_indices)
2248
2249            if (return_indices is False):
2250                y = pool(x)
2251                ref_y = pool(cpu_x)
2252
2253                cpu_grad = torch.ones_like(ref_y)
2254                grad = cpu_grad.to('mps')
2255
2256                y.backward(gradient=grad)
2257                ref_y.backward(gradient=cpu_grad)
2258
2259                self.assertEqual(y, ref_y)
2260                self.assertEqual(x.grad, cpu_x.grad)
2261            else:
2262                y, idx = pool(x)
2263                ref_y, ref_idx = pool(cpu_x)
2264
2265                cpu_grad = torch.ones_like(ref_y)
2266                grad = cpu_grad.to('mps')
2267
2268                y.backward(gradient=grad)
2269                ref_y.backward(gradient=cpu_grad)
2270
2271                self.assertEqual(y, ref_y)
2272                self.assertEqual(idx, ref_idx)
2273                self.assertEqual(x.grad, cpu_x.grad)
2274
2275        # Test with no batch dimension
2276        helper((8, 4, 4), ks=2)
2277        helper((2, 8, 4, 4), ks=2)
2278        helper((1, 1000, 32, 32), ks=4)
2279        helper((1, 1000, 1, 4), ks=(1, 4))  # test for max_pool1d
2280        # Test padding
2281        helper((1, 1000, 32, 32), ks=4, padding=1)
2282        helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1))  # test for max_pool1d
2283        # Test dilation
2284        helper((1, 1000, 32, 32), ks=4, dilation=2)
2285        helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2))  # test for max_pool1d
2286        # Test ceil mode
2287        helper((1, 1000, 32, 32), ks=4, ceil_mode=True)
2288        helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True)  # test for max_pool1d
2289
2290        # Test return indices
2291        for test_ties in [False, True]:
2292            # Test with no batch dimension
2293            helper((8, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
2294            helper((2, 8, 4, 4), ks=2, return_indices=True, test_ties=test_ties)
2295            helper((1, 1000, 32, 32), ks=4, return_indices=True, test_ties=test_ties)
2296            helper((1, 1000, 1, 4), ks=(1, 4), return_indices=True, test_ties=test_ties)  # test for max_pool1d
2297            # Test padding
2298            helper((1, 1000, 32, 32), ks=4, padding=1, return_indices=True, test_ties=test_ties)
2299            helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 1),
2300                   return_indices=True, test_ties=test_ties)  # test for max_pool1d
2301            # Test dilation
2302            helper((1, 1000, 32, 32), ks=4, dilation=2, return_indices=True, test_ties=test_ties)
2303            helper((1, 1000, 1, 4), ks=(1, 4), padding=(0, 2),
2304                   return_indices=True, test_ties=test_ties)  # test for max_pool1d
2305            # Test ceil mode
2306            helper((1, 1000, 32, 32), ks=4, ceil_mode=True, return_indices=True, test_ties=test_ties)
2307            helper((1, 1000, 1, 4), ks=(1, 4), ceil_mode=True,
2308                   return_indices=True, test_ties=test_ties)  # test for max_pool1d
2309
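    # Output-size sketch (a hypothetical method added for illustration): for
    # MaxPool2d, out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    # (ceil instead of floor when ceil_mode=True), which determines the shapes
    # the helper above compares between MPS and CPU.
    def _max_pool2d_out_size_sketch(self):
        pool = torch.nn.MaxPool2d(kernel_size=4, stride=4, padding=1)
        out = pool(torch.randn(1, 1, 32, 32))
        self.assertEqual(out.shape[-1], (32 + 2 * 1 - 1 * (4 - 1) - 1) // 4 + 1)
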
2310    def test_adaptive_avg_pool2d_output_size_one(self):
2311        def helper(size, memory_format):
2312            x = torch.randint(1, 10, size, dtype=torch.float, device='mps', requires_grad=True)
2313            if memory_format == 'non_contiguous':
2314                x = x[::2, ::2, ::2, ::2]
2315            else:
2316                x = x.to(memory_format=memory_format)
2317
2318            net = torch.nn.AdaptiveAvgPool2d((1, 1))
2319            out = net(x)
2320            ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))
2321
2322            out.sum().backward()    # make sure it doesn't crash
2323
2324            self.assertEqual(out, ref_out)
2325            if memory_format == torch.channels_last:
2326                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
2327                c = out.size(1)
2328                self.assertEqual(out.stride(), [c, 1, c, c])
2329            else:
2330                self.assertTrue(out.is_contiguous())
2331                c = out.size(1)
2332                self.assertEqual(out.stride(), [c, 1, 1, 1])
2333
2334        helper((2, 3, 6, 6), torch.contiguous_format)
2335
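    # Identity sketch (a hypothetical method added for illustration):
    # AdaptiveAvgPool2d((1, 1)) is a global mean over the spatial dims, which is
    # exactly the ref_out computed in the test above.
    def _adaptive_pool_global_mean_sketch(self):
        x = torch.randn(2, 3, 6, 6)
        out = torch.nn.AdaptiveAvgPool2d((1, 1))(x)
        torch.testing.assert_close(out, x.mean((-1, -2), keepdim=True))
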
2336    def test_masked_scatter(self):
2337        def helper(shape):
2338            x_mps = torch.randn(shape, device="mps")
2339            x_cpu = x_mps.detach().clone().cpu()
2340
2341            mask_mps = torch.rand(shape, device="mps") < 0.6
2342            mask_cpu = mask_mps.detach().clone().cpu()
2343
2344            y_mps = torch.randn(shape, device="mps")
2345            y_cpu = y_mps.detach().clone().cpu()
2346
2347            y_mps.masked_scatter_(mask_mps, x_mps)
2348            y_cpu.masked_scatter_(mask_cpu, x_cpu)
2349
2350            self.assertEqual(y_mps, y_cpu)
2351        helper([2, 5])
2352        helper([10, 10])
2353        helper([5, 10, 3])
2354        helper([10, 5, 10, 3])
2355        helper([10, 5, 10, 3, 20])
2356
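    # Semantics sketch (a hypothetical method added for illustration):
    # masked_scatter_ consumes source elements in row-major order, writing them
    # only at positions where the mask is True.
    def _masked_scatter_sketch(self):
        y = torch.zeros(4)
        mask = torch.tensor([True, False, True, False])
        y.masked_scatter_(mask, torch.tensor([1.0, 2.0]))
        torch.testing.assert_close(y, torch.tensor([1.0, 0.0, 2.0, 0.0]))
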
2357    def test_masked_fill(self):
2358        device = "mps"
2359        dtype = torch.float32
2360        mask_dtype = torch.bool
2361        num_dest = 10
2362
2363        dst = torch.zeros(num_dest, dtype=dtype, device=device)
2364        mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
2365        val = random.random()
2366        dst2 = torch.zeros(num_dest, dtype=dtype)
2367        mask_cpu = mask.to("cpu")
2368
2369        dst.masked_fill_(mask, val)
2370        for i in range(num_dest):
2371            if mask_cpu[i]:
2372                dst2[i] = val
2373        self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
2374
2375    def test_masked_fill__non_contiguous(self):
2376        shape = (3, 5)
2377
2378        x_mps = torch.randn(shape, device="mps")
2379        x_cpu = x_mps.detach().clone().cpu()
2380        mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool)
2381        mask_cpu = mask_mps.detach().clone().cpu()
2382
2383        x_mps_strided = x_mps.T
2384        x_cpu_strided = x_cpu.T
2385
2386        x_mps_strided.masked_fill_(mask_mps.T, float("-inf"))
2387        x_cpu_strided.masked_fill_(mask_cpu.T, float("-inf"))
2388
2389        self.assertEqual(x_mps_strided, x_cpu_strided)
2390        self.assertFalse((x_mps_strided == float("-inf")).any())
2391
2392    def test_nhwc_operation(self):
2393        def helper(shape, channels_last=False):
2394            import numpy as np
2395            np.random.seed(332)
2396            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
2397            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
2398            if (channels_last):
2399                cpu_x = cpu_x.to(memory_format=torch.channels_last)
2400                cpu_x.retain_grad()
2401            x = cpu_x.detach().clone().to('mps').requires_grad_()
2402
            # The copy to MPS should compare equal to the CPU tensor regardless of memory format
            self.assertEqual(x, cpu_x)
2405
2406        helper((2, 2, 2, 2), True)
2407
2408    # Test forward batch norm
2409    def test_batch_norm(self):
2410        def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last=False,
2411                   track_running_stats=True, test_module=False):
2412
2413            import numpy as np
2414            np.random.seed(332)
2415            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
2416            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
2417            if (channels_last):
2418                cpu_x = cpu_x.to(memory_format=torch.channels_last)
2419                cpu_x.retain_grad()
2420            x = cpu_x.detach().clone().to('mps').requires_grad_()
2421
2422            mean_shape = [shape[1]]
2423            cpu_running_mean = None
2424            cpu_running_var = None
2425            running_mean = None
2426            running_var = None
2427            if (track_running_stats):
2428                mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
2429                cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
2430                var_arr = 32 * np.random.random_sample(size=mean_shape)
2431                cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
2432                running_mean = cpu_running_mean.detach().clone().to('mps')
2433                running_var = cpu_running_var.detach().clone().to('mps')
2434
2435            weight = None
2436            cpu_weight = None
2437            bias = None
2438            cpu_bias = None
2439            if (wts):
2440                cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2441                weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2442                cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2443                bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2444
2445            y = None
2446            ref_y = None
2447
2448            if (not test_module):
2449                y = torch.nn.functional.batch_norm(x, running_mean, running_var,
2450                                                   weight=weight,
2451                                                   bias=bias,
2452                                                   training=training,
2453                                                   momentum=momentum, eps=eps)
2454                ref_y = torch.nn.functional.batch_norm(cpu_x, cpu_running_mean, cpu_running_var,
2455                                                       weight=cpu_weight,
2456                                                       bias=cpu_bias,
2457                                                       training=training,
2458                                                       momentum=momentum, eps=eps)
2459
2460            else:
2461
2462                batchnorm_op = None
2463                mps_batchnorm_op = None
2464
2465                if (len(shape) == 3):
2466                    batchnorm_op = torch.nn.BatchNorm1d(shape[1],
2467                                                        eps=eps,
2468                                                        momentum=momentum,
2469                                                        affine=wts,
2470                                                        track_running_stats=track_running_stats,
2471                                                        device='cpu')
2472                    mps_batchnorm_op = torch.nn.BatchNorm1d(shape[1],
2473                                                            eps=eps,
2474                                                            momentum=momentum,
2475                                                            affine=wts,
2476                                                            track_running_stats=track_running_stats,
2477                                                            device='mps')
2478                elif (len(shape) == 4):
2479                    batchnorm_op = torch.nn.BatchNorm2d(shape[1],
2480                                                        eps=eps,
2481                                                        momentum=momentum,
2482                                                        affine=wts,
2483                                                        track_running_stats=track_running_stats,
2484                                                        device='cpu')
2485                    mps_batchnorm_op = torch.nn.BatchNorm2d(shape[1],
2486                                                            eps=eps,
2487                                                            momentum=momentum,
2488                                                            affine=wts,
2489                                                            track_running_stats=track_running_stats,
2490                                                            device='mps')
2491                elif (len(shape) == 5):
2492                    batchnorm_op = torch.nn.BatchNorm3d(shape[1],
2493                                                        eps=eps,
2494                                                        momentum=momentum,
2495                                                        affine=wts,
2496                                                        track_running_stats=track_running_stats,
2497                                                        device='cpu')
2498                    mps_batchnorm_op = torch.nn.BatchNorm3d(shape[1],
2499                                                            eps=eps,
2500                                                            momentum=momentum,
2501                                                            affine=wts,
2502                                                            track_running_stats=track_running_stats,
2503                                                            device='mps')
2504
2505                if (track_running_stats):
2506                    batchnorm_op.running_mean = cpu_running_mean
2507                    batchnorm_op.running_var = cpu_running_var
2508                    mps_batchnorm_op.running_mean = running_mean
2509                    mps_batchnorm_op.running_var = running_var
2510                if (wts):
2511                    batchnorm_op.weight = torch.nn.Parameter(cpu_weight)
2512                    batchnorm_op.bias = torch.nn.Parameter(cpu_bias)
2513                    mps_batchnorm_op.weight = torch.nn.Parameter(weight)
2514                    mps_batchnorm_op.bias = torch.nn.Parameter(bias)
2515
2516                ref_y = batchnorm_op(cpu_x)
2517                y = mps_batchnorm_op(x)
2518
2519            self.assertEqual(y, ref_y)
2520            if (not test_module):
2521                self.assertEqual(running_mean, cpu_running_mean)
2522                self.assertEqual(running_var, cpu_running_var)
2523            else:
2524                self.assertEqual(mps_batchnorm_op.running_mean, batchnorm_op.running_mean)
2525                self.assertEqual(mps_batchnorm_op.running_var, batchnorm_op.running_var)
2526
2527            cpu_grad = torch.randn(ref_y.shape)
2528            grad = cpu_grad.to('mps')
2529            ref_y.backward(gradient=cpu_grad)
2530            y.backward(gradient=grad)
2531
2532            self.assertEqual(x.grad, cpu_x.grad)
2533            if (wts):
2534                if (not test_module):
2535                    self.assertEqual(weight.grad, cpu_weight.grad)
2536                    self.assertEqual(bias.grad, cpu_bias.grad)
2537                else:
2538                    self.assertEqual(mps_batchnorm_op.weight.grad, batchnorm_op.weight.grad)
2539                    self.assertEqual(mps_batchnorm_op.bias.grad, batchnorm_op.bias.grad)
2540
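            # Sweep 1d/2d/3d inputs through both the functional and the module path.
            # PyTorch's running-stat update convention is
            #   running_mean = (1 - momentum) * running_mean + momentum * batch_mean
            # so momentum=1.0 replaces the running stats with the batch statistics.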
2541        for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
2542            for test_module in [False, True]:
2543                for track_running_stats in [True, False]:
2544                    for channels_last in [False]:
2545                        if (channels_last and len(shape) != 4):
2546                            continue
2547                        # Eval mode (training=False) requires running stats to be tracked
2548                        if (track_running_stats):
2549                            helper(shape, eps=0, momentum=1, channels_last=channels_last,
2550                                   track_running_stats=track_running_stats, test_module=test_module)
2551                            helper(shape, channels_last=channels_last,
2552                                   track_running_stats=track_running_stats, test_module=test_module)
2553                            helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last,
2554                                   track_running_stats=track_running_stats, test_module=test_module)
2555                            helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last,
2556                                   track_running_stats=track_running_stats, test_module=test_module)
2557                            helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last,
2558                                   track_running_stats=track_running_stats, test_module=test_module)
2559                            helper(shape, eps=3, momentum=0.67, wts=True, training=False, channels_last=channels_last,
2560                                   track_running_stats=track_running_stats, test_module=test_module)
2561                        helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last,
2562                               track_running_stats=track_running_stats, test_module=test_module)
2563                        helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last,
2564                               track_running_stats=track_running_stats, test_module=test_module)
2565                        helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last,
2566                               track_running_stats=track_running_stats, test_module=test_module)
2567                        helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
2568                               track_running_stats=track_running_stats, test_module=test_module)
2569
2570    def test_batch_norm_backward(self):
2571        inputs = torch.rand(1, 8, 4, 4, device="mps", requires_grad=True)
2572        x = torch.nn.BatchNorm2d(8).to("mps")
2573        y = torch.nn.BatchNorm2d(8).to("mps")
2574        y.weight.requires_grad = False
2575        y.bias.requires_grad = False
2576        outputs = y(x(inputs))
2577        # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
2578        outputs.sum().backward()
2579
2580    # Regression test for https://github.com/pytorch/pytorch/issues/133520
2581    def test_batch_norm_slices(self):
2582        bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
2583        bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')
2584
2585        x_cpu = torch.randn(100, 100, 35, 45).to('cpu')
2586        x_mps = x_cpu.to('mps')
2587
2588        res_cpu = bn_cpu(x_cpu[5:])
2589        res_mps = bn_mps(x_mps[5:])
2590
2591        self.assertEqual(res_cpu, res_mps)
2592
2593    def test_layer_norm_backward(self):
2594        inputs = torch.rand(4, 4, device="mps", requires_grad=True)
2595        x = torch.nn.LayerNorm(4).to("mps")
2596        y = torch.nn.LayerNorm(4).to("mps")
2597        y.weight.requires_grad = False
2598        y.bias.requires_grad = False
2599        outputs = y(x(inputs))
2600        # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
2601        outputs.sum().backward()
2602
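        # torch.norm defaults to the Frobenius (2-)norm; p=inf reduces to max(abs(x)).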
2603    def test_norm(self):
2604        a = torch.arange(9, dtype=torch.float, device="mps") - 4
2605        b = a.reshape((3, 3))
2606
2607        a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
2608        b_cpu = a_cpu.reshape((3, 3))
2609
2610        res = torch.norm(a)
2611        res_cpu = torch.norm(a_cpu)
2612        self.assertEqual(res, res_cpu)
2613
2614        res = torch.norm(b)
2615        res_cpu = torch.norm(b_cpu)
2616        self.assertEqual(res, res_cpu)
2617
2618        res = torch.norm(a, float('inf'))
2619        res_cpu = torch.norm(a_cpu, float('inf'))
2620        self.assertEqual(res, res_cpu)
2621
2622        res = torch.norm(b, float('inf'))
2623        res_cpu = torch.norm(b_cpu, float('inf'))
2624        self.assertEqual(res, res_cpu)
2625
2626        c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps")
2627        c_cpu = torch.tensor([[1, 2, 3], [-1, 1, 4]] , dtype=torch.float, device="cpu")
2628
2629        res = torch.norm(c, dim=0)
2630        res_cpu = torch.norm(c_cpu, dim=0)
2631        self.assertEqual(res, res_cpu)
2632
2633        res = torch.norm(c, dim=1)
2634        res_cpu = torch.norm(c_cpu, dim=1)
2635        self.assertEqual(res, res_cpu)
2636
2637        res = torch.norm(c, p=1, dim=1)
2638        res_cpu = torch.norm(c_cpu, p=1, dim=1)
2639        self.assertEqual(res, res_cpu)
2640
2641        d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
2642        d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)
2643
2644        res = torch.norm(d, dim=(1, 2))
2645        res_cpu = torch.norm(d_cpu, dim=(1, 2))
2646        self.assertEqual(res, res_cpu)
2647
2648        res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
2649        res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :])
2650        self.assertEqual(res, res_cpu)
2651
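        # linalg.vector_norm accepts arbitrary real orders: ord=0 counts nonzero
        # entries, while ord=p computes (sum(abs(x)**p))**(1/p), e.g. p=3.5 below.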
2652    def test_linalg_vector_norm(self):
2653        x_mps = torch.tensor([0, 0, 0, 2, 3], dtype=torch.float, device="mps")
2654        x_cpu = x_mps.detach().clone().cpu()
2655
2656        res_mps = torch.linalg.vector_norm(x_mps, ord=0)
2657        res_cpu = torch.linalg.vector_norm(x_cpu, ord=0)
2658        self.assertEqual(res_mps, res_cpu)
2659
2660        a_mps = torch.arange(27, dtype=torch.float, device="mps") - 4
2661        a_cpu = torch.arange(27, dtype=torch.float, device="cpu") - 4
2662
2663        B_mps = a_mps.reshape(3, 3, 3)
2664        B_cpu = a_cpu.reshape(3, 3, 3)
2665
2666        res_mps = torch.linalg.vector_norm(a_mps, ord=3.5)
2667        res_cpu = torch.linalg.vector_norm(a_cpu, ord=3.5)
2668        self.assertEqual(res_mps, res_cpu)
2669
2670        res_mps = torch.linalg.vector_norm(B_mps, ord=3.5)
2671        res_cpu = torch.linalg.vector_norm(B_cpu, ord=3.5)
2672        self.assertEqual(res_mps, res_cpu)
2673
2674        for dim in range(0, B_mps.dim()):
2675            res_mps = torch.linalg.vector_norm(B_mps, ord=3.5, dim=dim)
2676            res_cpu = torch.linalg.vector_norm(B_cpu, ord=3.5, dim=dim)
2677            self.assertEqual(res_mps, res_cpu)
2678
2679
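        # LayerNorm normalizes over the trailing `normalized_shape` dimensions:
        #   y = (x - mean) / sqrt(var + eps) * weight + bias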
2680    def test_layer_norm(self):
2681        # TODO: Test non-contiguous
2682        def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):
2683            cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
2684            x = cpu_x.detach().clone().to('mps').requires_grad_()
2685
2686            cpu_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='cpu', dtype=dtype)
2687            mps_op = torch.nn.LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device='mps', dtype=dtype)
2688            cpu_wt = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
2689            wt = cpu_wt.detach().clone().to('mps').requires_grad_()
2690            cpu_bias = torch.randn(normalized_shape, device='cpu', dtype=dtype, requires_grad=True)
2691            bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2692
2693            if (elementwise_affine):
2694                cpu_op.weight = torch.nn.Parameter(cpu_wt)
2695                mps_op.weight = torch.nn.Parameter(wt)
2696                cpu_op.bias = torch.nn.Parameter(cpu_bias)
2697                mps_op.bias = torch.nn.Parameter(bias)
2698
2699            cpu_result = cpu_op(cpu_x)
2700            result = mps_op(x)
2701
2702            cpu_grad = torch.randn(cpu_result.shape)
2703            grad = cpu_grad.to('mps')
2704
2705            cpu_result.backward(cpu_grad)
2706            result.backward(grad)
2707
2708            self.assertEqual(result, cpu_result)
2709            self.assertEqual(x.grad, cpu_x.grad)
2710            if (elementwise_affine):
2711                self.assertEqual(mps_op.weight.grad, cpu_op.weight.grad)
2712                self.assertEqual(mps_op.bias.grad, cpu_op.bias.grad)
2713
2714        for elementwise_affine in [True, False]:
2715            helper((2, 2, 2, 2), (2, 2), elementwise_affine=elementwise_affine)
2716            helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine)
2717            helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine)
2718
2719        # Regression test for https://github.com/pytorch/pytorch/issues/96113
2720        torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(torch.randn(1, 2, 16).to("mps", dtype=torch.float16))
2721
2722    @xfailIf(product_version < 14.0)
2723    def test_ifft(self):
2724        # See: https://github.com/pytorch/pytorch/issues/124096
2725        device = torch.device("mps")
2726
2727        N = 64
2728        signal = torch.rand(N, device=device)
2729        fft_result = torch.fft.rfft(signal)
2730        ifft_result = torch.fft.irfft(fft_result, n=signal.shape[0])
2731
2732        # Expect the inverse transform to reconstruct the original signal
2733        self.assertEqual(ifft_result, signal)
2734
2735    # Regression test for https://github.com/pytorch/pytorch/issues/135223
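        # fftfreq(n, d) returns the DFT sample frequencies k / (n * d), with the
        # negative frequencies stored in the second half of the output.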
2736    def test_fftfreq(self):
2737        freq_cpu = torch.fft.fftfreq(10**4, device='cpu')
2738        freq_mps = torch.fft.fftfreq(10**4, device='mps')
2739        self.assertEqual(freq_cpu, freq_mps)
2740
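        # InstanceNorm computes the mean/variance of each (N, C) slice over the
        # spatial dimensions only, unlike BatchNorm which also reduces over N.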
2741    def test_instance_norm(self):
2742        def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
2743
2745            np.random.seed(332)
2746            arr = (256 - 128) * np.random.random_sample(size=shape) + 128
2747            cpu_x = torch.tensor(arr, device='cpu', dtype=torch.float, requires_grad=True)
2748            if (channels_last):
2749                cpu_x = cpu_x.to(memory_format=torch.channels_last)
2750                cpu_x.retain_grad()
2751            x = cpu_x.detach().clone().to('mps').requires_grad_()
2752
2753            mean_shape = [shape[1]]
2754            cpu_running_mean = None
2755            cpu_running_var = None
2756            running_mean = None
2757            running_var = None
2758            if (track_running_stats):
2759                mean_arr = (240 - 140) * np.random.random_sample(size=mean_shape) + 140
2760                cpu_running_mean = torch.tensor(mean_arr, device='cpu', dtype=torch.float)
2761                var_arr = 32 * np.random.random_sample(size=mean_shape)
2762                cpu_running_var = torch.tensor(var_arr, device='cpu', dtype=torch.float)
2763                running_mean = cpu_running_mean.detach().clone().to('mps')
2764                running_var = cpu_running_var.detach().clone().to('mps')
2765
2766            weight = None
2767            cpu_weight = None
2768            bias = None
2769            cpu_bias = None
2770            if (wts):
2771                cpu_weight = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2772                weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2773                cpu_bias = torch.randn(mean_shape, device='cpu', dtype=torch.float, requires_grad=True)
2774                bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2775
2776            y = None
2777            ref_y = None
2778
2779            if (not test_module):
2780                ref_y = torch.nn.functional.instance_norm(cpu_x, cpu_running_mean, cpu_running_var,
2781                                                          weight=cpu_weight,
2782                                                          bias=cpu_bias,
2783                                                          momentum=momentum, eps=eps)
2784                y = torch.nn.functional.instance_norm(x, running_mean, running_var,
2785                                                      weight=weight,
2786                                                      bias=bias,
2787                                                      momentum=momentum, eps=eps)
2788
2789            else:
2790
2791                instancenorm_op = None
2792                mps_instancenorm_op = None
2793
2794                if (len(shape) == 3):
2795                    instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
2796                                                              eps=eps,
2797                                                              momentum=momentum,
2798                                                              affine=wts,
2799                                                              track_running_stats=track_running_stats,
2800                                                              device='cpu')
2801                    mps_instancenorm_op = torch.nn.InstanceNorm1d(shape[1],
2802                                                                  eps=eps,
2803                                                                  momentum=momentum,
2804                                                                  affine=wts,
2805                                                                  track_running_stats=track_running_stats,
2806                                                                  device='mps')
2807                elif (len(shape) == 4):
2808                    instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
2809                                                              eps=eps,
2810                                                              momentum=momentum,
2811                                                              affine=wts,
2812                                                              track_running_stats=track_running_stats,
2813                                                              device='cpu')
2814                    mps_instancenorm_op = torch.nn.InstanceNorm2d(shape[1],
2815                                                                  eps=eps,
2816                                                                  momentum=momentum,
2817                                                                  affine=wts,
2818                                                                  track_running_stats=track_running_stats,
2819                                                                  device='mps')
2820                elif (len(shape) == 5):
2821                    instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
2822                                                              eps=eps,
2823                                                              momentum=momentum,
2824                                                              affine=wts,
2825                                                              track_running_stats=track_running_stats,
2826                                                              device='cpu')
2827                    mps_instancenorm_op = torch.nn.InstanceNorm3d(shape[1],
2828                                                                  eps=eps,
2829                                                                  momentum=momentum,
2830                                                                  affine=wts,
2831                                                                  track_running_stats=track_running_stats,
2832                                                                  device='mps')
2833
2834                if (track_running_stats):
2835                    instancenorm_op.running_mean = cpu_running_mean
2836                    instancenorm_op.running_var = cpu_running_var
2837                    mps_instancenorm_op.running_mean = running_mean
2838                    mps_instancenorm_op.running_var = running_var
2839                if (wts):
2840                    instancenorm_op.weight = torch.nn.Parameter(cpu_weight)
2841                    instancenorm_op.bias = torch.nn.Parameter(cpu_bias)
2842                    mps_instancenorm_op.weight = torch.nn.Parameter(weight)
2843                    mps_instancenorm_op.bias = torch.nn.Parameter(bias)
2844
2845                ref_y = instancenorm_op(cpu_x)
2846                y = mps_instancenorm_op(x)
2847
2848            self.assertEqual(y, ref_y)
2849            if (not test_module):
2850                self.assertEqual(running_mean, cpu_running_mean)
2851                self.assertEqual(running_var, cpu_running_var)
2852            else:
2853                self.assertEqual(mps_instancenorm_op.running_mean, instancenorm_op.running_mean)
2854                self.assertEqual(mps_instancenorm_op.running_var, instancenorm_op.running_var)
2855
2856            cpu_grad = torch.randn(ref_y.shape)
2857            grad = cpu_grad.to('mps')
2858            ref_y.backward(gradient=cpu_grad)
2859            y.backward(gradient=grad)
2860
2861            self.assertEqual(x.grad, cpu_x.grad)
2862            if (wts):
2863                if (not test_module):
2864                    self.assertEqual(weight.grad, cpu_weight.grad)
2865                    self.assertEqual(bias.grad, cpu_bias.grad)
2866                else:
2867                    self.assertEqual(mps_instancenorm_op.weight.grad, instancenorm_op.weight.grad)
2868                    self.assertEqual(mps_instancenorm_op.bias.grad, instancenorm_op.bias.grad)
2869
2870        for shape in [(2, 3, 2, 2), (2, 3, 2, 2, 2), (2, 3, 2)]:
2871            for test_module in [False, True]:
2872                for track_running_stats in [True, False]:
2873                    for channels_last in [False]:
2874                        if (channels_last and len(shape) != 4):
2875                            continue
2876                        # Extra configurations that require running stats to be tracked
2877                        if (track_running_stats):
2878                            helper(shape, eps=0, momentum=1, channels_last=channels_last,
2879                                   track_running_stats=track_running_stats, test_module=test_module)
2880                            helper(shape, channels_last=channels_last,
2881                                   track_running_stats=track_running_stats, test_module=test_module)
2882                            helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
2883                                   track_running_stats=track_running_stats, test_module=test_module)
2884                            helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
2885                                   track_running_stats=track_running_stats, test_module=test_module)
2886                            helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
2887                                   track_running_stats=track_running_stats, test_module=test_module)
2888                            helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
2889                                   track_running_stats=track_running_stats, test_module=test_module)
2890                        helper(shape, eps=1e-05, momentum=0.1, wts=False, channels_last=channels_last,
2891                               track_running_stats=track_running_stats, test_module=test_module)
2892                        helper(shape, eps=0, momentum=1.0, wts=False, channels_last=channels_last,
2893                               track_running_stats=track_running_stats, test_module=test_module)
2894                        helper(shape, eps=1, momentum=1, wts=True, channels_last=channels_last,
2895                               track_running_stats=track_running_stats, test_module=test_module)
2896                        helper(shape, eps=3, momentum=0.67, wts=True, channels_last=channels_last,
2897                               track_running_stats=track_running_stats, test_module=test_module)
2898
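        # weight_norm reparametrizes a weight as w = g * v / ||v|| along `dim`;
        # the parametrization stores g and v as `original0` and `original1`.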
2899    def test_weight_norm(self):
2900        def validate_weight_norm_equality(model, cpu_model, x, cpu_x, dim):
2901            cpu_norm = torch.nn.utils.parametrizations.weight_norm(cpu_model, dim=dim)
2902            norm = torch.nn.utils.parametrizations.weight_norm(model, dim=dim)
2903
2904            cpu_out = cpu_norm(cpu_x)
2905            out = norm(x)
2906
2907            self.assertEqual(cpu_out, out)
2908
2909            cpu_grad = torch.randn(cpu_out.shape)
2910            grad = cpu_grad.to('mps')
2911            cpu_out.backward(gradient=cpu_grad)
2912            out.backward(gradient=grad)
2913
2914            self.assertEqual(cpu_model.parametrizations.weight.original0.grad, model.parametrizations.weight.original0.grad)
2915            self.assertEqual(cpu_model.parametrizations.weight.original1.grad, model.parametrizations.weight.original1.grad)
2916
2917            self.assertEqual(x.grad, cpu_x.grad)
2918
2919        def helper(dim, layer='linear', dtype=torch.float32):
2920            # linear layer
2921            if layer == 'linear':
2922                cpu_x = torch.randn((2, 5), device='cpu', dtype=dtype, requires_grad=True)
2923                x = cpu_x.detach().clone().to('mps').requires_grad_()
2924
2925                cpu_weight = torch.randn(10, 5, device='cpu', dtype=dtype, requires_grad=True)
2926                weight = cpu_weight.detach().clone().to('mps').requires_grad_()
2927
2928                cpu_bias = torch.randn(10, device='cpu', dtype=dtype, requires_grad=True)
2929                bias = cpu_bias.detach().clone().to('mps').requires_grad_()
2930
2931                cpu_linear = torch.nn.Linear(5, 10, device='cpu')
2932                linear = torch.nn.Linear(5, 10, device='mps')
2933
2934                with torch.no_grad():
2935                    cpu_linear.weight.copy_(cpu_weight)
2936                    cpu_linear.bias.copy_(cpu_bias)
2937                    linear.weight.copy_(weight)
2938                    linear.bias.copy_(bias)
2939                validate_weight_norm_equality(linear, cpu_linear, x, cpu_x, dim)
2940
2941            # conv layer
2942            if layer == 'conv':
2943                cpu_x = torch.randn((3, 5, 5), device='cpu', dtype=dtype, requires_grad=True)
2944                x = cpu_x.detach().clone().to('mps').requires_grad_()
2945
2946                cpu_conv = torch.nn.Conv2d(3, 3, 3, device='cpu')
2947                conv = torch.nn.Conv2d(3, 3, 3, device='mps')
2948
2949                with torch.no_grad():
2950                    conv.weight.copy_(cpu_conv.weight)
2951                    conv.bias.copy_(cpu_conv.bias)
2952
2953                validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)
2954
2955            # conv3d layer
2956            if layer == 'conv3d':
2957                cpu_x = torch.randn((3, 5, 5, 4), device='cpu', dtype=dtype, requires_grad=True)
2958                x = cpu_x.detach().clone().to('mps').requires_grad_()
2959
2960                cpu_conv = torch.nn.Conv3d(3, 3, 3, device='cpu')
2961                conv = torch.nn.Conv3d(3, 3, 3, device='mps')
2962
2963                with torch.no_grad():
2964                    conv.weight.copy_(cpu_conv.weight)
2965                    conv.bias.copy_(cpu_conv.bias)
2966
2967                validate_weight_norm_equality(conv, cpu_conv, x, cpu_x, dim)
2968
2969        helper(0, layer='linear')
2970        helper(1, layer='linear')
2971        helper(-1, layer='linear')
2972
2973        helper(0, layer='conv')
2974        helper(1, layer='conv')
2975        helper(2, layer='conv')
2976        helper(3, layer='conv')
2977        helper(-1, layer='conv')
2978
2979        if product_version >= 13.2:
2980            # Conv3d is only available from MacOS 13.2 onwards
2981            helper(0, layer='conv3d')
2982            helper(1, layer='conv3d')
2983            helper(2, layer='conv3d')
2984            helper(3, layer='conv3d')
2985            helper(4, layer='conv3d')
2986            helper(-1, layer='conv3d')
2987
2988    # Test conv2d
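        # With grouped convolutions each filter only sees C_in // groups input
        # channels, hence the weight shape (C_out, C_in // groups, kH, kW) below
        # (e.g. groups=4 with C_in=16 gives filters of depth 4). Note that each
        # configuration is run twice with identical arguments.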
2989    def test_conv2d_unit(self):
2990        def helper(input_shape, wt_shape,
2991                   stride=1, padding=0,
2992                   dilation=1, groups=1,
2993                   bias_shape=None):
2994
2995            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
2996            x = cpu_x.detach().clone().to('mps').requires_grad_()
2997
2998            cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
2999            wt = cpu_wt.detach().clone().to('mps').requires_grad_()
3000
3001            cpu_bias = None
3002            bias = None
3003
3004            if (bias_shape is not None):
3005                cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
3006                bias = cpu_bias.detach().clone().to('mps').requires_grad_()
3007
3008            y = torch.nn.functional.conv2d(x, wt, bias=bias, stride=stride,
3009                                           padding=padding, dilation=dilation, groups=groups)
3010            ref_y = torch.nn.functional.conv2d(cpu_x, cpu_wt, bias=cpu_bias, stride=stride,
3011                                               padding=padding, dilation=dilation, groups=groups)
3012
3013            cpu_grad = torch.ones_like(ref_y)
3014            grad = cpu_grad.to('mps')
3015
3016            y.backward(gradient=grad)
3017            ref_y.backward(gradient=cpu_grad)
3018
3019            self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
3020            self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
3021            self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
3022            if (bias_shape is not None):
3023                self.assertEqual(bias.grad, cpu_bias.grad, atol=8e-04, rtol=10.4e-05)
3024
3025        N = 1
3026        C_in = 3
3027        C_out = 64
3028        H = 64
3029        W = 64
3030        kH = 4
3031        kW = 4
3032        stride = 2
3033        padding = 1
3034
3035        helper((N, C_in, H, W), (C_out, C_in, kH, kW), stride=stride, padding=padding)
3036
3037        N = 4
3038        C_in = 16
3039        H = 32
3040        W = 32
3041
3042        C_out = 8
3043        kH = 3
3044        kW = 3
3045
3046        for groups in [1, 2, 4]:
3047            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups)
3048            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), groups=groups)
3049
3050            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups)
3051            helper((N, C_in, H, W), (C_out, C_in // groups, kH, kW), bias_shape=(C_out), groups=groups)
3052
3053            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups)
3054            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups, kH + 2, kW + 2), groups=groups)
3055
3056            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups,
3057                   kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups)
3058            helper((N, C_in * 2, H * 2, W * 2), (C_out * 2, (C_in * 2) // groups,
3059                   kH + 2, kW + 2), bias_shape=(C_out * 2), groups=groups)
3060
3061    # Test conv transpose 2d
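        # For conv_transpose2d the output height is
        #   H_out = (H_in - 1) * stride - 2 * padding + dilation * (kH - 1) + output_padding + 1
        # PyTorch requires output_padding to be smaller than either stride or
        # dilation; the loop below conservatively requires it to be smaller than both.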
3062    def test_conv_transpose2d(self):
3063        def helper(input_shape, wt_shape,
3064                   stride=1, padding=0,
3065                   output_padding=0,
3066                   dilation=1, groups=1,
3067                   bias_shape=None):
3068
3069            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
3070            x = cpu_x.detach().clone().to('mps').requires_grad_()
3071
3072            cpu_wt = torch.randn(wt_shape, device='cpu', dtype=torch.float, requires_grad=True)
3073            wt = cpu_wt.detach().clone().to('mps').requires_grad_()
3074
3075            cpu_bias = None
3076            bias = None
3077
3078            if (bias_shape is not None):
3079                cpu_bias = torch.randn(bias_shape, device='cpu', dtype=torch.float, requires_grad=True)
3080                bias = cpu_bias.detach().clone().to('mps').requires_grad_()
3081
3082            y = torch.nn.functional.conv_transpose2d(
3083                x, wt, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
3084            ref_y = torch.nn.functional.conv_transpose2d(
3085                cpu_x, cpu_wt, bias=cpu_bias, stride=stride, padding=padding,
3086                output_padding=output_padding, groups=groups, dilation=dilation)
3087
3088            cpu_grad = torch.randn(ref_y.shape)
3089            grad = cpu_grad.to('mps')
3090
3091            y.backward(gradient=grad)
3092            ref_y.backward(gradient=cpu_grad)
3093
3094            self.assertEqual(y, ref_y, rtol=2.6e-05, atol=2e-04)
3095            self.assertEqual(x.grad, cpu_x.grad, rtol=2.6e-06, atol=2e-05)
3096            self.assertEqual(wt.grad, cpu_wt.grad, atol=8e-04, rtol=10.4e-05)
3097
3098            # Note: the bias gradient comparison is currently disabled:
3099            # if (bias_shape is not None):
3100            #     self.assertEqual(bias.grad, cpu_bias.grad)
3102
3103        N = 4
3104        C_in = 2
3105        H = 32
3106        W = 32
3107
3108        C_out = 8
3109        groups = 1
3110        kH = 3
3111        kW = 3
3112
3113        for stride in [1, 2, 3]:
3114            for padding in [0, 1, 2]:
3115                for output_padding in [0, 1, 2]:
3116                    for dilation in [1, 2]:
3117                        if (output_padding >= stride or output_padding >= dilation):
3118                            continue
3119                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride,
3120                               padding=padding, output_padding=output_padding, dilation=dilation)
3121                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), stride=stride,
3122                               padding=padding, output_padding=output_padding, dilation=dilation)
3123
3124                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
3125                               padding=padding, output_padding=output_padding, dilation=dilation)
3126                        helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
3127                               padding=padding, output_padding=output_padding, dilation=dilation)
3128
3129    # Test sigmoid
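        # sigmoid(x) = 1 / (1 + exp(-x)); the backward comparison below covers
        # its derivative y * (1 - y).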
3130    def test_sigmoid(self):
3131        def helper(shape):
3132
3133            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3134            x = cpu_x.detach().clone().to('mps').requires_grad_()
3135
3136            sigmoid_op = torch.nn.Sigmoid()
3137
3138            y = sigmoid_op(x)
3139            ref_y = sigmoid_op(cpu_x)
3140
3141            cpu_grad = torch.ones_like(ref_y)
3142            grad = cpu_grad.to('mps')
3143
3144            y.backward(gradient=grad)
3145            ref_y.backward(gradient=cpu_grad)
3146
3147            self.assertEqual(y, ref_y)
3148            self.assertEqual(x.grad, cpu_x.grad)
3149
3150        helper((2, 3, 4, 5))
3151        helper((2, 3, 4))
3152        helper((2, 8, 4, 5))
3153
3154    # Test tanh
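        # tanh'(x) = 1 - tanh(x)**2, exercised via the backward comparison below.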
3155    def test_tanh(self):
3156        def helper(shape):
3157
3158            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
3159            x = cpu_x.detach().clone().to('mps').requires_grad_()
3160
3161            tanh_op = torch.nn.Tanh()
3162
3163            y = tanh_op(x)
3164            ref_y = tanh_op(cpu_x)
3165
3166            cpu_grad = torch.ones_like(ref_y)
3167            grad = cpu_grad.to('mps')
3168
3169            y.backward(gradient=grad)
3170            ref_y.backward(gradient=cpu_grad)
3171
3172            self.assertEqual(y, ref_y)
3173            self.assertEqual(x.grad, cpu_x.grad)
3174
3175        helper((2, 3, 4, 5))
3176        helper((2, 3, 4))
3177        helper((2, 8, 4, 5))
3178
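        # nn.Threshold(threshold, value) computes y = x if x > threshold else value.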
3179    def test_threshold(self):
3180        def helper(threshold, value, num_elems, inplace=False, requires_grad=True):
3181            m = nn.Threshold(threshold=threshold, value=value, inplace=inplace)
3182
3183            input_cpu = torch.randn(num_elems, requires_grad=requires_grad, dtype=torch.float)
3184            input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)
3185
3186            output_cpu = m(input_cpu)
3187            output_mps = m(input_mps)
3188
3189            cpu_grad = torch.ones_like(output_cpu)
3190            mps_grad = cpu_grad.to('mps')
3191
3192            self.assertEqual(output_cpu, output_mps)
3193
3194            if requires_grad:
3195                output_cpu.backward(gradient=cpu_grad)
3196                output_mps.backward(gradient=mps_grad)
3197
3198                self.assertEqual(input_cpu.grad, input_mps.grad)
3199
3200        helper(threshold=0.1, value=20, num_elems=2)
3201        helper(threshold=-0.1, value=10, num_elems=10)
3202        helper(threshold=0.5, value=-15, num_elems=100)
3203        helper(threshold=1, value=10, num_elems=100, inplace=True, requires_grad=False)
3204
3205    # Test pow
3206    def test_pow(self):
3207        def helper(shape):
3208            # aten::pow.Tensor_Tensor
3209            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3210            x = cpu_x.detach().clone().to('mps')
3211            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3212            y = cpu_y.detach().clone().to('mps')
3213            z = torch.pow(x, y)
3214            ref_z = torch.pow(cpu_x, cpu_y)
3215
3216            self.assertEqual(z, ref_z)
3217
3218            # aten::pow.Tensor_Scalar
3219            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3220            x = cpu_x.detach().clone().to('mps')
3221            exp = random.random()
3222            z = torch.pow(x, exp)
3223            ref_z = torch.pow(cpu_x, exp)
3224
3225            self.assertEqual(z, ref_z)
3226
3227            # aten::pow.Scalar
3228            x = random.random()
3229            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3230            y = cpu_y.detach().clone().to('mps')
3231            z = torch.pow(x, y)
3232            ref_z = torch.pow(x, cpu_y)
3233
3234            self.assertEqual(z, ref_z)
3235
3236        helper((2, 8, 4, 5))
3237
3238    # Test addcmul
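        # addcmul computes input + value * tensor1 * tensor2 elementwise.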
3239    def test_addcmul(self):
3240        def helper(shape, value, xtype=torch.float32, ytype=None, ztype=None):
3241            def rand_helper(dtype):
3242                if dtype.is_floating_point:
3243                    return torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
3244                return torch.randint(10, shape, dtype=dtype, device='cpu', requires_grad=False)
3245
3246            cpu_x = rand_helper(xtype)
3247            x = cpu_x.detach().clone().to('mps')
3248
3249            cpu_y = rand_helper(ytype if ytype is not None else xtype)
3250            y = cpu_y.detach().clone().to('mps')
3251
3252            cpu_z = rand_helper(ztype if ztype is not None else xtype)
3253            z = cpu_z.detach().clone().to('mps')
3254
3255            y = torch.addcmul(x, y, z, value=value)
3256            ref_y = torch.addcmul(cpu_x, cpu_y, cpu_z, value=value)
3257
3258            self.assertEqual(y, ref_y)
3259
3260        helper((2, 3, 4, 5), 0.1)
3261        helper((2, 8, 4, 5), 0.1)
3262        helper((2, 3, 4, 5), 0.2)
3263        helper((2, 8, 4, 5), 0.2)
3264        # Integral types
3265        helper((2, 2), 1.0, xtype=torch.int32)
3266        helper((2, 2), 2.0, xtype=torch.int16)
3267
3268        # Mixed types
3269        helper((2, 2), 1.0, xtype=torch.float16, ytype=torch.float32)
3270        helper((3, 2), 1.0, ytype=torch.float16)
3271        helper((2, 3), 1.0, ztype=torch.float16)
3272        helper((2, 2), 1.0, xtype=torch.int32, ytype=torch.int16, ztype=torch.uint8)
3273        helper((2, 2), 1.0, ytype=torch.int16, ztype=torch.uint8)
3274
3275    # Test addcdiv
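        # addcdiv computes input + value * tensor1 / tensor2 elementwise, hence
        # the clamp below keeping the divisor away from zero.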
3276    def test_addcdiv(self):
3277        def helper(shape, value):
3278            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3279            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3280            # clamp to avoid division by 0
3281            cpu_z = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False).clamp_min_(0.1)
3282            cpu_out = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
3283
3284            mps_x = cpu_x.detach().clone().to('mps')
3285            mps_y = cpu_y.detach().clone().to('mps')
3286            mps_z = cpu_z.detach().clone().to('mps')
3287            mps_out = cpu_out.detach().clone().to('mps')
3288
3289            result_div_mps = torch.addcdiv(mps_x, mps_y, mps_z, value=value)
3290            result_div_cpu = torch.addcdiv(cpu_x, cpu_y, cpu_z, value=value)
3291            self.assertEqual(result_div_mps, result_div_cpu)
3292            # test .out variant
3293            self.assertEqual(torch.addcdiv(mps_x, mps_y, mps_z, out=mps_out, value=value), result_div_cpu)
3294
3295        helper((2, 3, 4, 5), 0.1)
3296        helper((2, 8, 4, 5), 0.2)
3297        helper((2, 3, 4, 5), 1.0)  # value of 1 should be ignored internally
3298
3299    def test_addcdiv_transpose(self):
3300        # Regression test for issue https://github.com/pytorch/pytorch/issues/118115
3301        # Tests all combinations of contiguous and transposed input tensors
3302
3303        def helper(shape, value):
3304            shape_t = shape[::-1]
3305            for i in range(2):
3306                for j in range(2):
3307                    for k in range(2):
3308                        x = torch.rand(shape, device="cpu") if i == 0 else torch.rand(shape_t, device="cpu").t()
3309                        y = torch.rand(shape, device="cpu") if j == 0 else torch.rand(shape_t, device="cpu").t()
3310                        z = torch.rand(shape, device="cpu") if k == 0 else torch.rand(shape_t, device="cpu").t()
3311
3312                        x_mps = x.detach().clone().to(device="mps")
3313                        y_mps = y.detach().clone().to(device="mps")
3314                        z_mps = z.detach().clone().to(device="mps")
3315
3316                        result_cpu = x.addcdiv_(y, z, value=value)
3317                        result_mps = x_mps.addcdiv(y_mps, z_mps, value=value)
                            # start from a zero-filled out buffer so the out= check cannot pass vacuously
3318                        result_mps_out = torch.zeros_like(x_mps)
3319                        torch.addcdiv(x_mps, y_mps, z_mps, out=result_mps_out, value=value)
3320
3321                        self.assertEqual(result_cpu, result_mps)
3322                        self.assertEqual(result_cpu, result_mps_out)
3323
3324        helper((2, 3), 1.0)
3325        helper((2, 3), 0.2)
3326        helper((100, 300), 1.0)
3327        helper((100, 300), 0.2)
3328
3329    def test_buffer_size_match(self):
3330        # this test shouldn't cause any crash
3331        size = 16
3332        cpu_A = torch.rand(size, device='cpu')
3333        cpu_F = torch.rand(size, size, size, device='cpu')
3334
3335        mps_A = cpu_A.to('mps')
3336        mps_F = cpu_F.to('mps')
3337        self.assertEqual(cpu_A @ cpu_F, mps_A @ mps_F)
3338
3339    def test_transpose_inplace(self):
3340        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
3341        cpu_x = torch.tensor(values, device='cpu')
3342        mps_x = torch.tensor(values, device='mps')
3343
3344        cpu_x.transpose_(0, 1)
3345        mps_x.transpose_(0, 1)
3346        self.assertEqual(cpu_x, mps_x.to('cpu'))
3347
3348    def test_expand_cpu_to_mps_copy(self):
3349        # https://github.com/pytorch/pytorch/issues/78642
3350
3351        x = torch.tensor(1).expand([10]).to("mps")
3352        x_cpu = torch.tensor(1).expand([10])
3353
3354        self.assertEqual(x_cpu, x.cpu())
3355
3356    def test_cpu_to_strided_mps_copy(self):
3357        # https://github.com/pytorch/pytorch/issues/86975
3358
3359        a1 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
3360        b1 = torch.Tensor([-1, -1])
3361        a1[1:, 1] = b1
3362
3363        a2 = torch.Tensor([[1, 2], [3, 4], [5, 6]]).to(torch.device("mps"))
3364        b2 = torch.Tensor([-1, -1]).to(torch.device("mps"))
3365        a2[1:, 1] = b2
3366
3367        self.assertEqual(a1, a2)
3368
3369    def test_view_slice_reshape(self):
3370        x = torch.randn([1, 4, 4], device="mps")
3371        y = x[0, :1, 1:]
3372
3373        x_cpu = x.to("cpu")
3374        y_cpu = x_cpu[0, :1, 1:]
3375
3376        r = y + 1
3377        r_cpu = y_cpu + 1
3378        self.assertEqual(r, r_cpu)
3379
3380    def test_slice_reshape(self):
3381        x = torch.randn([1, 6, 4, 2], dtype=torch.float, device="mps")
3382        x_cpu = x.detach().clone().to("cpu")
3383
3384        x = x[:, 3:].view(2, 3, 4, 1)
3385        x_cpu = x_cpu[:, 3:].view(2, 3, 4, 1)
3386        self.assertEqual(x, x_cpu)
3387
3388        x = x + 2
3389        x_cpu = x_cpu + 2
3390        self.assertEqual(x, x_cpu)
3391
3392    def test_reshape_storage_offset(self):
3393        # https://github.com/pytorch/pytorch/issues/95883
3394        B = 4
3395        T = 1
3396
3397        lin_cpu = nn.Linear(10, 256)
3398        lin_mps = nn.Linear(10, 256, device="mps")
3399
3400        # Use the same weights and bias as the ones from the cpu
3401        lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
3402        lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()
3403
3404        x_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
3405        x_cpu = x_mps.detach().clone().cpu().requires_grad_()
3406        x_mps = lin_mps(x_mps)
3407        x_cpu = lin_cpu(x_cpu)
3408
3409        self.assertEqual(x_mps.shape, (B, T, 256))
3410        self.assertEqual(x_cpu.shape, (B, T, 256))
3411
3412        cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
3413        cls_token_cpu = cls_token_mps.detach().clone().cpu()
3414        x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
3415        x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)
3416
3417        x_mps = x_mps.transpose(0, 1)
3418        x_cpu = x_cpu.transpose(0, 1)
3419
3420        target_mps = torch.rand_like(x_mps)
3421        target_cpu = target_mps.detach().clone().cpu()
3422        loss_mps = F.mse_loss(x_mps, target_mps)
3423        loss_cpu = F.mse_loss(x_cpu, target_cpu)
3424        self.assertEqual(loss_mps, loss_cpu)
3425
3426        loss_mps.backward()
3427        loss_cpu.backward()
3428        self.assertEqual(x_mps.grad, x_cpu.grad)
3429
3430    def test_stack_storage_offset(self):
3431        # https://github.com/pytorch/pytorch/issues/87856
3432        x_cpu = torch.tensor([[1, 2]])
3433        x_mps = x_cpu.detach().clone().to("mps")
3434
3435        y_cpu = torch.stack((x_cpu[:, :1], x_cpu[:, -1:]), dim=-1)
3436        y_mps = torch.stack((x_mps[:, :1], x_mps[:, -1:]), dim=-1)
3437
3438        self.assertEqual(y_cpu, y_mps)
3439
3440        t_mps = torch.tensor([1, 2, 3, 4], device="mps")
3441        t_cpu = t_mps.detach().cpu()
3442
3443        x_mps = t_mps[2:]
3444        y_mps = t_mps[:2]
3445
3446        x_cpu = t_cpu[2:]
3447        y_cpu = t_cpu[:2]
3448
3449        res_mps = torch.stack((y_mps, x_mps), dim=-1)
3450        res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
3451
3452        self.assertEqual(res_mps, res_cpu)
3453
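        # unsafe_chunk behaves like chunk, but without enforcing autograd's
        # restrictions on in-place modification of the outputs.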
3454    def test_unsafe_chunk(self):
3455        # https://github.com/pytorch/pytorch/issues/91065
3456        a = torch.rand(5, dtype=torch.float32, device="cpu")
3457        ret = a.unsafe_chunk(4, 0)
3458        y = ret[0] * ret[2]
3459        a_mps = a.to("mps")
3460        ret_mps = a_mps.unsafe_chunk(4, 0)
3461        y_mps = ret_mps[0] * ret_mps[2]
3462        self.assertEqual(y, y_mps)
3463
3464    def test_slice_casting(self):
3465        # generate random binary numbers
3466        cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8)
3467        mps_in = cpu_in.detach().clone().to("mps")
3468        # check copy_cast(uint8 -> bool) on tensors with storage offset
3469        cpu_out = cpu_in[:, :, 11 : 12, :12].to(torch.bool)
3470        mps_out = mps_in[:, :, 11 : 12, :12].to(torch.bool)
3471        self.assertEqual(cpu_out, mps_out)
3472
3473    def test_slice_reshape_contg_view(self):
3475
3476        x_mps = torch.randn(1, 4800, 2, device="mps")
3477        x_cpu = x_mps.detach().clone().cpu()
3478
3479        r_mps = x_mps + 2
3480        r_cpu = x_cpu + 2
3481
3482        self.assertEqual(r_mps, r_cpu)
3483
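        # Exhaustively takes 2-D slices of tensors of rank 2 through 5; the `+ 1`
        # forces an actual op on each sliced view before comparison.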
3484    def test_contiguous_slice_2d(self):
3485        def helper(shape):
3486            for i in range(0, shape[0]):
3487                for j in range(0, shape[1]):
3488                    t_mps = torch.randn(shape, device="mps")
3489                    t_cpu = t_mps.detach().clone().cpu()
3490
3491                    y_mps = t_mps[i:, :j]
3492                    y_cpu = t_cpu[i:, :j]
3493                    self.assertEqual(y_mps + 1, y_cpu + 1)
3494
3495                    y_mps = t_mps[i:, j]
3496                    y_cpu = t_cpu[i:, j]
3497                    self.assertEqual(y_mps + 1, y_cpu + 1)
3498
3499                    y_mps = t_mps[i, :j]
3500                    y_cpu = t_cpu[i, :j]
3501                    self.assertEqual(y_mps + 1, y_cpu + 1)
3502
3503                    y_mps = t_mps[:i, :j]
3504                    y_cpu = t_cpu[:i, :j]
3505                    self.assertEqual(y_mps + 1, y_cpu + 1)
3506
3507                    y_mps = t_mps[:i, j]
3508                    y_cpu = t_cpu[:i, j]
3509                    self.assertEqual(y_mps + 1, y_cpu + 1)
3510
3511                    y_mps = t_mps[:i, j:]
3512                    y_cpu = t_cpu[:i, j:]
3513                    self.assertEqual(y_mps + 1, y_cpu + 1)
3514
3515        l = []
3516        for N in range(1, 3):
3517            l.append(N)
3518            for C in range(1, 3):
3519                l.append(C)
3520                helper(l)
3521                for D in range(1, 3):
3522                    l.append(D)
3523                    helper(l)
3524                    for H in range(1, 3):
3525                        l.append(H)
3526                        helper(l)
3527                        for W in range(1, 3):
3528                            l.append(W)
3529                            helper(l)
3530                            l.pop()
3531                        l.pop()
3532                    l.pop()
3533                l.pop()
3534            l.pop()
3535
3536        helper([9, 15, 4])
3537        helper([9, 3, 2])
3538        helper([3, 4, 18, 22])
3539        helper([3, 4, 18, 22, 150])
3540
3541    def test_contiguous_slice_3d(self):
3542        x = torch.randn(2, 3, 3, device="mps")
3543        x_cpu = x.detach().clone().cpu()
3544        x = x[:1]
3545        x_cpu = x_cpu[:1]
3546        out = x[:, 0:1, 0:1] * x[:, 1:2, 1:2]
3547        out_cpu = x_cpu[:, 0:1, 0:1] * x_cpu[:, 1:2, 1:2]
3548        self.assertEqual(out, out_cpu)
3549
3550    def test_view_slice(self):
3551        # https://github.com/pytorch/pytorch/issues/83995
3552        NUM_SAMPLES = 60
3553        s = (0, 1)
3554
3555        X = torch.rand(8000, 3, dtype=torch.float32, device='cpu')
3556        X_mps = X.detach().clone().to("mps")
3557
3558        idx = torch.randint(0, X.shape[0], (1,)).repeat(len(s))
3559        pts = torch.randint(0, X.shape[0], (NUM_SAMPLES, X.shape[1]))
3560        idx_mps = idx.to("mps")
3561        pts_mps = pts.to("mps")
3562        pts[:, s] = idx
3563        pts_mps[:, s] = idx_mps
3564
3565        actual_pts = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float)
3566        actual_pts_mps = torch.zeros(NUM_SAMPLES, X.shape[1], dtype=torch.float, device="mps")
3567
3568        for i in range(NUM_SAMPLES):
3569            for j in range(X.shape[1]):
3570                actual_pts_mps[i, j] = X_mps[pts_mps[i, j], j]
3571                actual_pts[i, j] = X[pts[i, j], j]
3572                self.assertEqual(actual_pts[i, j], actual_pts_mps[i, j])
3573
3574    def test_slice_scatter(self):
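            # Copying into every other column of a larger buffer must leave the
            # source tensor untouched.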
3575        shape = (4, 4)
3576        tensor = torch.randint(10, shape, device="mps")
3577        tensor_before = tensor.clone()
3578        torch.empty(shape[0], shape[1] * 2, device="mps")[:, ::2].copy_(tensor)
3579        torch.testing.assert_close(tensor, tensor_before)
3580
3581    def test_slice(self):
3582        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
3583        cpu_x = torch.tensor(values, device='cpu')
3584        mps_x = torch.tensor(values, device='mps', dtype=torch.float)
3585
3586        cpu_slice1 = cpu_x[:2, :]
3587        mps_slice1 = mps_x[:2, :]
3588        self.assertEqual(cpu_slice1, mps_slice1)
3589
3590        cpu_slice2 = cpu_x[:, :1]
3591        mps_slice2 = mps_x[:, :1]
3592        self.assertEqual(cpu_slice2, mps_slice2)
3593
3594        cpu_slice3 = cpu_x[1:2, :]
3595        mps_slice3 = mps_x[1:2, :]
3596        self.assertEqual(cpu_slice3, mps_slice3.to('cpu'))
3597
3598        cpu_slice4 = cpu_x[1, :]
3599        mps_slice4 = mps_x[1, :].to('cpu')
3600        self.assertEqual(cpu_slice4, mps_slice4)
3601
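        # Checks that chained view ops (slice, reshape, transpose, permute, expand)
        # produce the same values and storage_offset on MPS as on CPU.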
3602    @parametrize("torch_type", arg_values=[torch.float16, torch.float32, torch.bfloat16])
3603    def test_slice_view_api(self, torch_type: torch.dtype):
3604
3605        def helper(x_tensor, y_func, z_func, r_func=None):
3606            x_mps = x_tensor.detach().clone().to("mps")
3607
3608            y = y_func(x_tensor)
3609            y_mps = y_func(x_mps)
3610            self.assertEqual(y, y_mps)
3611
3612            z = z_func(y)
3613            z_mps = z_func(y_mps)
3614            self.assertEqual(z, z_mps)
3615            self.assertEqual(z.storage_offset(), z_mps.storage_offset())
3616
3617            if r_func:
3618                r = r_func(z)
3619                r_mps = r_func(z_mps)
3620                self.assertEqual(r, r_mps)
3621
3622        # Skip bfloat16 before MacOS 15
3623        if not (product_version < 15.0 and torch_type == torch.bfloat16):
3624            # Tests for previously encountered MPS bugs
3625            helper(
3626                torch.randn(4, 4, dtype=torch_type),
3627                lambda x: x[1],
3628                lambda y: y.reshape(2, 2),
3629                lambda z: z + 1
3630            )
3631            helper(
3632                torch.randn(2, 4, dtype=torch_type),
3633                lambda x: x[1],
3634                lambda y: y + torch.ones(4, device=y.device)
3635            )
3636            helper(
3637                torch.randn(4, 6, dtype=torch_type),
3638                lambda x: x[1],
3639                lambda y: y.reshape(3, 2).t(),
3640                lambda z: z + 1
3641            )
3642            helper(
3643                torch.arange(4, dtype=torch_type).resize(1, 2, 2),
3644                lambda x: x.permute(2, 0, 1),
3645                lambda y: y + 1
3646            )
3647            helper(
3648                torch.randn(4, 8, dtype=torch_type),
3649                lambda x: x.transpose(0, 1).reshape(-1),
3650                lambda y: y[:2],
3651                lambda z: z + 1
3652            )
3653            helper(
3654                torch.randn(1, dtype=torch_type),
3655                lambda x: x.expand(2, 3),
3656                lambda y: y + torch.ones(2, 3, device=y.device)
3657            )
3658
3659    def test_slice_reshape_contiguous(self):
3660        x = torch.randn(4, 4)
3661        x_mps = x.detach().clone().to("mps")
3662
3663        y = x[1]
3664        y_mps = x_mps[1]
3665        self.assertEqual(y, y_mps)
3666
3667        z = y.reshape(2, 2)
3668        z_mps = y_mps.reshape(2, 2)
3669        self.assertEqual(z, z_mps)
3670        self.assertEqual(z.storage_offset(), z_mps.storage_offset())
3671
3672    def test_scalar_from_slice_unary(self):
3673        # https://github.com/pytorch/pytorch/issues/82543
3674        tensor_list = torch.tensor([1.0, 1.2], device="mps")
3675
3676        for scalar in tensor_list:
3677            r_mps = torch.ceil(scalar)
3678            r_cpu = torch.ceil(scalar.to("cpu"))
3679            self.assertEqual(r_mps.cpu(), r_cpu)
3680
3681    def test_scalar_from_slice_binary(self):
3682        # https://github.com/pytorch/pytorch/issues/82543
3683        def helper(binary_op):
3684            tensor_list = torch.tensor([1.0, 1.2, 2.5, 1.0], device="mps")
3685
3686            for scalar in tensor_list:
3687                r_mps = binary_op(scalar, 1.0)
3688                r_cpu = binary_op(scalar.cpu(), 1.0)
3689                self.assertEqual(r_mps.cpu(), r_cpu)
3690        helper(torch.sub)
3691        helper(torch.add)
3692        helper(torch.not_equal)
3693        helper(torch.eq)
3694
3695    def test_slice_contiguous_view(self):
3696        # https://github.com/pytorch/pytorch/issues/77750
3697
3698        def helper(operator):
3699            t_mps = torch.tensor([1, 2, 3, 4], device="mps")
3700            t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")
3701
3702            # contiguous view
3703            x_mps = t_mps[2:]  # 3, 4
3704            y_mps = t_mps[:2]  # 1, 2
3705
3706            x_cpu = t_cpu[2:]
3707            y_cpu = t_cpu[:2]
3708
3709            res_mps = res_cpu = None
3710            if operator == "<=":
3711                res_mps = x_mps <= y_mps
3712                res_cpu = x_cpu <= y_cpu
3713            elif operator == "<":
3714                res_mps = x_mps < y_mps
3715                res_cpu = x_cpu < y_cpu
3716            elif operator == ">=":
3717                res_mps = x_mps >= y_mps
3718                res_cpu = x_cpu >= y_cpu
            elif operator == ">":
                res_mps = x_mps > y_mps
                res_cpu = x_cpu > y_cpu
            elif operator == "==":
                res_mps = x_mps == y_mps
                res_cpu = x_cpu == y_cpu
            elif operator == "!=":
                res_mps = x_mps != y_mps
                res_cpu = x_cpu != y_cpu
            elif operator == "stack":
                res_mps = torch.stack((y_mps, x_mps), dim=-1)
                res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)

            self.assertEqual(res_mps, res_cpu)

        for op in ["<=", "<", ">=", ">", "==", "!=", "stack"]:
            helper(op)

    def test_slice_of_slice(self):
        x = torch.tensor([0.5, 0.5], device="cpu")
        x_mps = torch.tensor([0.5, 0.5], device="mps")

        tensor = x[1][None]
        tensor_mps = x_mps[1][None]

        res = tensor.ne(0)
        res_mps = tensor_mps.ne(0)

        self.assertEqual(res, res_mps)

    def test_index_storage_offset(self):
        # https://github.com/pytorch/pytorch/issues/78107

        a = torch.tensor([8.2670e-01, -1.0293e+00])
        b_cpu = a[0]
        c_cpu = a[1]

        # both 'b' and 'c' are views of 'a'
        # 'b' has a storage offset of 0, while 'c' has a storage offset of 1
        # when copying from 'cpu' to 'mps', the storage offset of 1 in 'c' needs to be taken
        # into account, otherwise 'c' ends up with the same value as 'b'
        b = b_cpu.to('mps')
        c = c_cpu.to('mps')

        res_mps = b > c
        res_cpu = b_cpu > c_cpu
        self.assertEqual(res_mps, res_cpu)

        res_mps = c > b
        res_cpu = c_cpu > b_cpu
        self.assertEqual(res_mps, res_cpu)

    def test_flatten(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        cpu_flatten1 = cpu_x.flatten()
        mps_flatten1 = mps_x.flatten().to('cpu')
        self.assertEqual(cpu_flatten1, mps_flatten1)

        cpu_flatten2 = cpu_x.flatten(start_dim=1)
        mps_flatten2 = mps_x.flatten(start_dim=1).to('cpu')
        self.assertEqual(cpu_flatten2, mps_flatten2)

        cpu_flatten3 = cpu_x.flatten(end_dim=1)
        mps_flatten3 = mps_x.flatten(end_dim=1).to('cpu')
        self.assertEqual(cpu_flatten3, mps_flatten3)

    # Test repeat
    def test_repeat(self):
        def helper(shape, repeats):

            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            y = x.repeat(repeats)
            ref_y = cpu_x.repeat(repeats)

            cpu_grad = torch.randn(ref_y.shape)
            grad = cpu_grad.to('mps')

            y.backward(gradient=grad)
            ref_y.backward(gradient=cpu_grad)

            self.assertEqual(y, ref_y)
            self.assertEqual(x.grad, cpu_x.grad)

        helper((2, 3, 4, 5), (2, 3, 4, 5))
        helper((2, 3, 4), (4, 3, 2, 5, 7, 2))
        helper((3, 4, 5), (2, 3, 4, 5))
        helper((3, 4, 5), (2, 2, 2))

    def test_torch_repeat_interleave(self, device="mps"):
        y = torch.tensor([[1, 2], [3, 4]], device=device)
        # exercise single argument function signature
        temp = y.repeat_interleave(2)
        self.assertEqual(torch.Size([8]), temp.size())

        for dtype in [torch.int, torch.long]:
            lengths = torch.tensor([1, 2], dtype=dtype, device="mps")
            output_size = torch.sum(lengths)
            a = torch.repeat_interleave(
                y,
                lengths,
                dim=0,
            )
            self.assertEqual(a.dtype, y.dtype)
            self.assertEqual(a.size(), torch.Size([3, 2]))

            a_with_output = torch.repeat_interleave(
                y,
                lengths,
                dim=0,
                output_size=output_size,
            )
            self.assertEqual(a_with_output.dtype, y.dtype)
            self.assertEqual(a_with_output.size(), torch.Size([3, 2]))

    def test_repeat_interleave(self, device="mps"):
        x = torch.tensor([0, 1, 2, 3], device=device)
        expected = torch.tensor([1, 2, 2, 3, 3, 3], device=device)
        # Prior to macOS 13.3, input of dtype=torch.int64 returns dtype=torch.int32
        self.assertEqual(torch.repeat_interleave(x), expected, exact_dtype=product_version >= 13.3)

        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(torch.arange(4, device=device).reshape(2, 2))

        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(torch.arange(4.0, device=device))

        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4], device=device))

        y = torch.tensor([[1, 2], [3, 4]], device=device)

        y1_v1 = torch.repeat_interleave(y, 2)
        y1_v2 = torch.repeat_interleave(y, torch.tensor(2, device=device))
        y1_v3 = torch.repeat_interleave(y, torch.tensor([2], device=device))
        y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4], device=device)
        self.assertEqual(y1_v1, y1_expect)
        self.assertEqual(y1_v2, y1_expect)
        self.assertEqual(y1_v3, y1_expect)

        y2 = torch.repeat_interleave(y, 3, dim=1)
        y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2],
                                  [3, 3, 3, 4, 4, 4]], device=device)
        self.assertEqual(y2, y2_expect)

        y3 = torch.repeat_interleave(y, torch.tensor([1, 2], device=device), dim=0)
        y3_expect = torch.tensor([[1, 2],
                                  [3, 4],
                                  [3, 4]], device=device)
        self.assertEqual(y3, y3_expect)

        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(y, torch.tensor([1, 2, 3], device=device), dim=0)

        with self.assertRaises(RuntimeError):
            torch.repeat_interleave(y, torch.arange(9, device=device).reshape(3, 3), dim=0)

        # test zero sized dimension
        x = torch.zeros((5, 0), device=device)
        y = torch.repeat_interleave(x, repeats=3, dim=1)
        self.assertEqual(y, x.new_zeros(5, 0, device=device))

        x = torch.tensor([], dtype=torch.int64, device=device)
        y = torch.repeat_interleave(x, x)
        self.assertEqual(y, x)

    def test_repeat_interleave_simple(self):
        def helper(shape, dtype=torch.float32, num_repeats=torch.Tensor(), dim=None):
            x = torch.randn(shape, dtype=dtype, device="mps")
            x_cpu = x.detach().clone().cpu()

            num_repeats_cpu = num_repeats.detach().clone().cpu()

            repeats = torch.repeat_interleave(x, num_repeats, dim)
            repeats_cpu = torch.repeat_interleave(x_cpu, num_repeats_cpu, dim)

            self.assertEqual(repeats, repeats_cpu)
        helper(shape=3, num_repeats=torch.tensor([100], device="mps"))
        helper(shape=(2, 2), num_repeats=torch.tensor([3, 3], device="mps"), dim=0)
        helper(shape=(10, 15, 8), num_repeats=torch.arange(10, device="mps"), dim=0)
        helper(shape=(10, 15, 8), num_repeats=torch.randint(0, 100, (15, ), device="mps"), dim=1)
        helper(shape=(10, 15, 30), num_repeats=torch.randint(0, 100, (30, ), device="mps"), dim=2)

    def test_count_nonzero(self):
        def helper(dtype):
            n = [
                [[1, 0, 2], [3, 0, 2], [7, 9, -4]],
                [[0, 2, 3], [3, 2, 1], [2, 0, 0]],
            ]
            cpu_x = torch.tensor(n, dtype=dtype)
            mps_x = torch.tensor(n, dtype=dtype).to('mps')

            # All non-zeros
            self.assertEqual(
                torch.count_nonzero(cpu_x),
                torch.count_nonzero(mps_x)
            )

            # dim=1
            self.assertEqual(
                torch.count_nonzero(cpu_x, dim=1),
                torch.count_nonzero(mps_x, dim=1)
            )

            # dim=(0, 1)
            self.assertEqual(
                torch.count_nonzero(cpu_x, dim=(0, 1)),
                torch.count_nonzero(mps_x, dim=(0, 1))
            )
        helper(torch.int32)
        helper(torch.int64)
        helper(torch.float16)
        helper(torch.float32)

    def _test_module_empty_input(self, module, inp, check_size=True):
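        # Shared helper: run `module` on an empty input and verify that the
        # forward output (optionally) keeps the input size and that backward
        # produces all-zero gradients for the parameters and the input.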
        inp.requires_grad_(True)
        out = module(inp)
        gO = torch.rand_like(out)
        out.backward(gO)
        if check_size:
            self.assertEqual(out.size(), inp.size())
        for p in module.parameters():
            if p.requires_grad:
                self.assertEqual(p.grad, torch.zeros_like(p.grad))
        self.assertEqual(inp.grad, torch.zeros_like(inp))

    # Test dtype casting, with and without simultaneous device change
    def test_to(self):
        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
        cpu_x = torch.tensor(values, device='cpu')
        mps_x = torch.tensor(values, device='mps')

        self.assertEqual(cpu_x.int(), mps_x.int().cpu())
        self.assertEqual(cpu_x.bool(), mps_x.bool().cpu())
        self.assertEqual(cpu_x.float(), mps_x.float().cpu())

        self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(),
                         torch.tensor(1, dtype=torch.int32))
        self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(),
                         torch.tensor(1, dtype=torch.int32))
        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(),
                         torch.tensor(1.0))
        self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int),
                         torch.tensor(4, dtype=torch.int32))
        self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(),
                         torch.tensor(4, dtype=torch.int32))
        self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
                         torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))
        # Cast int8 and uint8 to float and compare results
        # See https://github.com/pytorch/pytorch/issues/80009 for more details
        cpu_byte = torch.tensor([60, 160, 20, 220], dtype=torch.uint8)
        cpu_char = torch.tensor([60, -60, 20, -120], dtype=torch.int8)
        for x_cpu in [cpu_byte, cpu_char]:
            x_mps = x_cpu.to('mps')
            self.assertEqual(x_mps.to(torch.float32), x_cpu.to(torch.float32))


    def test_setitem_scalar(self) -> None:
        device = 'mps'
        for dtype in [torch.int32, torch.float32, torch.int64]:
            for i in range(3, 6):
                for j in range(3, 6):
                    t = torch.zeros(i, j, dtype=dtype, device=device)
                    self.assertEqual(t.sum(), 0)
                    t[1, 1] = 1
                    t[2, 1] = j
                    t[1, 2] = i
                    self.assertEqual(t[1, 1], 1)
                    self.assertEqual(t[1, 2], i)
                    self.assertEqual(t[2, 1], j)
                    self.assertEqual(t.sum(), 1 + i + j)

    def test_stride_of_strides(self) -> None:
        x = torch.rand(32, 1, device='mps')
        y = x.as_strided(size=(32, 2), stride=(1, 0))
        # Copying a strided view of a strided tensor to the CPU used to crash with a
        # "buffer is not large enough" assert
        # See https://github.com/pytorch/pytorch/issues/79181#issuecomment-1154683435
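        # Note: with stride=(1, 0) the extra columns are broadcast views of
        # column 0, so `y` and `z` reference overlapping memory and the copy
        # to CPU cannot be a flat buffer copy.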
        z = y.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
        self.assertEqual(x.to("cpu").as_strided(size=(32, 3), stride=(1, 0)), z)

    def test_type_casting(self):
        # https://github.com/pytorch/pytorch/issues/81567
        def helper(data, to_dtype):
            a_cpu = torch.tensor(data)
            a_mps = a_cpu.to(torch.device('mps'))

            res_cpu = a_cpu.type(to_dtype)
            res_mps = a_mps.type(to_dtype)
            self.assertEqual(res_cpu, res_mps)

        helper([9.0, 3.0, 5.0, 4.0], torch.LongTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.FloatTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.IntTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.ShortTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.HalfTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.CharTensor)
        helper([9.0, 3.0, 5.0, 4.0], torch.ByteTensor)

    def test_to_casting(self):
        # https://github.com/pytorch/pytorch/issues/81567
        def helper(data, to_dtype):
            a_cpu = torch.tensor(data)
            a_mps = a_cpu.to(torch.device('mps'))

            res_cpu = a_cpu.to(to_dtype)
            res_mps = a_mps.to(to_dtype)
            self.assertEqual(res_cpu, res_mps)

        helper([9.0, 3.0, 5.0, 4.0], torch.int64)
        helper([9.0, 3.0, 5.0, 4.0], torch.float)
        helper([9.0, 3.0, 5.0, 4.0], torch.int32)
        helper([9.0, 3.0, 5.0, 4.0], torch.short)
        helper([9.0, 3.0, 5.0, 4.0], torch.half)
        helper([9.0, 3.0, 5.0, 4.0], torch.int8)
        helper([9.0, 3.0, 5.0, 4.0], torch.uint8)

    def test_storage_offset_greater_than_src_nbytes(self):
        # https://github.com/pytorch/pytorch/issues/80844
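        # Each slice below is a view into `elems`; as `i` grows, the view's
        # storage offset in bytes exceeds the nbytes of the view itself, which
        # is the CPU->MPS copy case the issue above tracked.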
        n_tensors = 100
        n_tensor_elems = 784
        elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)

        tensor_list = []
        for i in range(0, n_tensors - 1):
            # create a list of contiguous view tensors (view tensor created by the slice op)
            t = elems[n_tensor_elems * i : n_tensor_elems * (i + 1)]
            tensor_list.append(t)

        for i in range(0, n_tensors - 1):
            t = tensor_list[i].view(1, n_tensor_elems)
            t_mps = t.to("mps")
            self.assertEqual(t, t_mps.cpu(), f"i={i}")

    # See https://github.com/pytorch/pytorch/issues/82427
    # and https://github.com/pytorch/pytorch/issues/83692
    def test_full_bugs(self):
        # Test should not crash
        x = torch.full((3, 3), True, device='mps')
        # torch.full should work for uint8
        y_mps = torch.full((2, 2), 247, device='mps', dtype=torch.uint8)
        y_cpu = torch.full((2, 2), 247, device='cpu', dtype=torch.uint8)
        self.assertEqual(y_mps, y_cpu)

    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
    # See https://github.com/pytorch/pytorch/issues/84995
    def test_div_bugs(self):
        for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):
            if dtype != torch.int64:
                x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
                y = torch.div(x, 101, rounding_mode=mode)
                self.assertEqual(y.sum(), 0)

    # See https://github.com/pytorch/pytorch/issues/82663
    def test_bool_expand(self):
        x = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')
        y = torch.tensor([0, 1], dtype=torch.bool, device='mps')
        self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2)))

    def test_int_expand(self):
        x = torch.tensor([[1], [0]], dtype=torch.int8, device='mps')
        y = torch.tensor([0, 1], dtype=torch.int8, device='mps')
        self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2)))

    # Empty unary op should return tensor of the same size
    def test_empty_neg(self):
        x = torch.tensor([[]], device='mps')
        y = -x
        self.assertEqual(x, y)

    def _test_unique_scalar_empty(self, dtype, device, f):
        # test scalar
        x = torch.tensor(0, dtype=dtype, device=device)
        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
        expected_unique = torch.tensor([0], dtype=dtype, device=device)
        expected_inverse = torch.tensor(0, device=device)
        expected_counts = torch.tensor([1], device=device)
        self.assertEqual(unique, expected_unique)
        self.assertEqual(inverse, expected_inverse)
        self.assertEqual(counts, expected_counts)

        # test zero sized tensor
        x = torch.zeros((0, 0, 3), dtype=dtype, device=device)
        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
        expected_unique = torch.tensor([], dtype=dtype, device=device)
        expected_inverse = torch.empty((0, 0, 3), dtype=torch.long, device=device)
        expected_counts = torch.tensor([], dtype=torch.long, device=device)
        self.assertEqual(unique, expected_unique)
        self.assertEqual(inverse, expected_inverse)
        self.assertEqual(counts, expected_counts)

    def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape):
        def ensure_tuple(x):
            if isinstance(x, torch.Tensor):
                return (x,)
            return x

        for return_inverse in [True, False]:
            for return_counts in [True, False]:
                # test with expected
                ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
                self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                self.assertEqual(expected_unique, ret[0])
                if return_inverse:
                    self.assertEqual(expected_inverse, ret[1])
                if return_counts:
                    count_index = 1 + int(return_inverse)
                    self.assertEqual(expected_counts, ret[count_index])

                # tests per-element unique on a higher rank tensor.
                y = x.view(additional_shape)
                y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True)
                self.assertEqual(expected_unique, y_unique)
                self.assertEqual(expected_inverse.view(additional_shape), y_inverse)
                self.assertEqual(expected_counts, y_counts)

    def test_unique_all_dtypes(self, device="mps"):
        def helper(dtype):
            def ensure_tuple(x):
                if isinstance(x, torch.Tensor):
                    return (x,)
                return x

            if dtype is torch.bool:
                x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device)
                expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device)
                expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device)
                expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device)
            else:
                x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device)
                expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device)
                expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
                expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)

            # test sorted unique
            fs = (
                lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs),
                lambda x, **kwargs: x.unique(sorted=True, **kwargs),
            )
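            # x_sliced holds the same values as x but through a stride-2 view,
            # so the expectations above also apply to a non-contiguous input.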
            x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x)
            xs = (x, x_sliced)
            for f, x in product(fs, xs):
                self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2))
                self._test_unique_scalar_empty(dtype, device, f)

            # test unsorted unique
            fs = (
                lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs),
                lambda x, **kwargs: x.unique(sorted=False, **kwargs)
            )
            for f, x in product(fs, xs):
                self._test_unique_scalar_empty(dtype, device, f)
                for return_inverse, return_counts in product((True, False), repeat=2):
                    ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
                    self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                    x_list = x.tolist()
                    x_unique_list = ret[0].tolist()
                    self.assertEqual(expected_unique.tolist(), sorted(x_unique_list))
                    if return_inverse:
                        x_inverse_list = ret[1].tolist()
                        for i, j in enumerate(x_inverse_list):
                            self.assertEqual(x_list[i], x_unique_list[j])
                    if return_counts:
                        count_index = 1 + int(return_inverse)
                        x_counts_list = ret[count_index].tolist()
                        for i, j in zip(x_unique_list, x_counts_list):
                            count = 0
                            for k in x_list:
                                if k == i:
                                    count += 1
                            self.assertEqual(j, count)
        [helper(dtype) for dtype in [torch.float32, torch.int64, torch.int32, torch.int16, torch.uint8]]

    def test_unique(self):
        def helper(x, return_inverse, return_counts):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            result = torch.unique(x, return_inverse=return_inverse, return_counts=return_counts)
            result_cpu = torch.unique(cpu_x, return_inverse=return_inverse, return_counts=return_counts)

            self.assertEqual(result, result_cpu)
        helper(torch.tensor([1, 2, 4, 2, 1]), False, False)
        helper(torch.randint(3, (10, )), False, False)
        helper(torch.randint(3, (10, )), True, False)
        helper(torch.randint(3, (10, )), False, True)
        helper(torch.randint(3, (10, )), True, True)
        helper(torch.randint(3, (1, )), True, True)
        helper(torch.randint(3, (0, )), True, True)
        # Regression test for https://github.com/pytorch/pytorch/issues/104879
        x = torch.arange(2, device="mps")
        self.assertEqual(x.reshape(1, 1, 2).unique(), x)

    def test_unique_consecutive(self):
        def helper(x, dim, return_inverse, return_counts):
            cpu_x = x
            x = cpu_x.detach().clone().to('mps')

            result = torch.unique_consecutive(x, dim=dim, return_inverse=return_inverse, return_counts=return_counts)
            result_cpu = torch.unique_consecutive(cpu_x, dim=dim, return_inverse=return_inverse, return_counts=return_counts)

            self.assertEqual(result, result_cpu)
        helper(torch.tensor([1, 2, 4, 2, 1]), 0, False, False)
        helper(torch.randint(3, (10, )), 0, False, False)
        helper(torch.randint(3, (10, )), 0, True, False)
        helper(torch.randint(3, (10, )), 0, False, True)
        helper(torch.randint(3, (10, )), 0, True, True)
        helper(torch.randint(3, (10, )), 0, True, True)
        helper(torch.randint(3, (1, )), 0, True, True)
        helper(torch.randint(3, (0, )), 0, True, True)

        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, False, False)
        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 0, True, True)
        helper(torch.randint(2, (20, 2)), 0, True, True)
        helper(torch.randint(2, (1, 2)), 0, True, True)
        helper(torch.randint(2, (0, 2)), 0, True, True)

        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, False, False)
        helper(torch.tensor([[1, 1, 2, 3, 3, 2], [1, 1, 1, 2, 2, 1]]), 1, True, True)
        helper(torch.randint(2, (2, 20)), 1, True, True)
        helper(torch.randint(2, (2, 1)), 1, True, True)
        helper(torch.randint(2, (2, 0)), 1, True, True)

    # See https://github.com/pytorch/pytorch/issues/85675
    def test_cat_non_contiguous(self):
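        # With channels_last data, slicing along the height dim yields
        # non-contiguous views (asserted in the helper), so this exercises
        # torch.concat on non-contiguous inputs along every dim.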
        def rotate_subset(data, dim):
            x1 = data[:, :, :2, :]
            x2 = data[:, :, 2:, :]
            self.assertFalse(x1.is_contiguous())
            self.assertFalse(x2.is_contiguous())
            return torch.concat((x1, x2), dim=dim)
        for dtype in MPS_DTYPES:
            if dtype == torch.bool:
                continue
            data = torch.arange(48, dtype=dtype).reshape(1, 2, 4, 6)
            data = data.to(memory_format=torch.channels_last)
            mps_data = data.to("mps")
            self.assertEqual(data, mps_data)
            for dim in range(data.dim()):
                cpu_result = rotate_subset(data, dim)
                mps_result = rotate_subset(mps_data, dim)
                self.assertEqual(cpu_result, mps_result.to("cpu"))
                # TODO: enable memory format test
                # self.assertEqual(cpu_result.is_contiguous(), mps_result.is_contiguous())

    # See https://github.com/pytorch/pytorch/issues/85967
    def test_from_numpy_non_contiguous(self):
        a = np.arange(9).reshape(3, 3)[:, :2]
        t_cpu = torch.tensor(a, device="cpu")
        t_mps = torch.tensor(a, device="mps")
        self.assertEqual(t_cpu, t_mps.to("cpu"))

    # See https://github.com/pytorch/pytorch/issues/86954
    def test_copy_non_contiguous(self):
        x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
        self.assertFalse(x.is_contiguous())
        y = x.to('mps')
        self.assertFalse(y.is_contiguous())
        self.assertEqual(x, y.to('cpu'))

        x = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
        y = x.to('mps')
        self.assertEqual(x, y.to('cpu'))

        x = torch.full((4, 4, 4, 4), 13, device="cpu")
        y = torch.full((4, 4, 4, 4), 13, device="mps")
        z = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
        x.permute(3, 2, 1, 0)[1::, ::2] = z
        # As y is on MPS and z on CPU, this dispatches to a copy operator
        y.permute(3, 2, 1, 0)[1::, ::2] = z
        self.assertEqual(x, y.to('cpu'))

    # See https://github.com/pytorch/pytorch/issues/95417
    def test_copy_storage_offset(self):
        x_cpu = torch.zeros(5, device="cpu", dtype=torch.float32)
        x_mps = torch.zeros(5, device="mps", dtype=torch.float32)
        update_cpu = torch.tensor([1, 1], device="cpu", dtype=torch.int64)
        update_mps = torch.tensor([1, 1], device="mps", dtype=torch.int64)
        x_cpu[2:4] = update_cpu
        x_mps[2:4] = update_mps  # implicit type casting and copy
        self.assertEqual(x_cpu, x_mps)

        x_cpu[2:4] = update_mps  # implicit device moving and copy
        self.assertEqual(x_cpu, x_mps)

    def test_copy_broadcasting(self):
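        # copy_ broadcasts src to dst's shape and casts dtypes, e.g. a (2, 1)
        # src copied into a (2, 3) dst repeats the singleton column; MPS and
        # CPU results must match for every dtype pair below.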
        def helper(src_shape, dst_shape, src_dtype, dst_dtype):
            cpu_src = torch.randint(0, 127, src_shape).to(src_dtype)
            cpu_dst = torch.randint(0, 127, dst_shape).to(dst_dtype)
            cpu_result = cpu_dst.copy_(cpu_src)
            mps_src = cpu_src.to("mps")
            mps_dst = cpu_dst.to("mps")
            mps_result = mps_dst.copy_(mps_src)
            self.assertEqual(cpu_result, mps_result)

        test_dtypes = [torch.float32, torch.int32, torch.int16, torch.int8]

        for (src_dtype, dst_dtype) in itertools.product(test_dtypes, test_dtypes):
            helper((2, 1), (2, 3), src_dtype, dst_dtype)
            helper((2, 1), (2, 2), src_dtype, dst_dtype)
            helper((3, 1, 4, 1), (3, 4, 4, 5), src_dtype, dst_dtype)
            helper((3,), (2, 3), src_dtype, dst_dtype)
            helper((2,), (2, 2), src_dtype, dst_dtype)
            helper((4, 1, 5), (3, 4, 4, 5), src_dtype, dst_dtype)
            helper((4, 1, 5), (4, 0, 5), src_dtype, dst_dtype)
            helper((1, 5), (4, 0, 5), src_dtype, dst_dtype)
            helper((3, 1, 0), (3, 5, 0), src_dtype, dst_dtype)
            helper((0, 1, 0), (0, 5, 0), src_dtype, dst_dtype)
        # Regression test for https://github.com/pytorch/pytorch/issues/107867
        self.assertEqual(torch.tensor([[1]], device='mps').item(), 1.0)

    # See https://github.com/pytorch/pytorch/pull/84742
    # and https://github.com/pytorch/pytorch/pull/78319
    def test_binops_dtype_precedence(self):
        # Test dtype precedence (casting order) in binary operations by comparing to CPU result
        # Example values for all dtypes supported on the MPS backend
        sample_vals = {
            torch.bool: [False, True],
            torch.int16: [-15, 0, 1, 10],
            torch.int32: [-376, 0, 1, 13],
            torch.int64: [-8, 0, 1, 77],
            torch.float16: [-234.5, 0.0, 1.0, 2.0],
            torch.float32: [-1.0, 0.0, 0.1, 111.99],
        }
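        # Type promotion must match CPU exactly; e.g. int16 + float32 promotes
        # to float32, while a 0-dim operand only bumps the result dtype when
        # its dtype category (floating vs. integral) is higher.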
        # Test all combinations of dtypes, operations, dimensionality
        for dtype1, dtype2, binop in itertools.product(
                sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul', 'div']):
            # bool minus bool is generally unsupported, so skip
            if binop == 'sub' and (dtype1 == torch.bool or dtype2 == torch.bool):
                continue
            full_shape = (10,)
            for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]):
                # print(f'{dtype1},{dtype2}: ({val1}).{binop}({val2})')
                # print(getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                #            (torch.tensor(val2, dtype=dtype2, device='mps')))
                # print(getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                #            (torch.tensor(val2, dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                           (torch.tensor(val2, dtype=dtype2, device='mps')),
                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                           (torch.tensor(val2, dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
                           (torch.tensor([val2], dtype=dtype2, device='mps')),
                    getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
                           (torch.tensor([val2], dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                           (torch.tensor([val2], dtype=dtype2, device='mps')),
                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                           (torch.tensor([val2], dtype=dtype2, device='cpu')))
                self.assertEqual(
                    getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop)
                           (torch.tensor(val2, dtype=dtype2, device='mps')),
                    getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop)
                           (torch.tensor(val2, dtype=dtype2, device='cpu')))
                # Test tensors created with torch.full
                x1 = torch.full(full_shape, val1, dtype=dtype1, device='mps')
                y1 = torch.tensor(val2, dtype=dtype2, device='mps')
                x2 = torch.full(full_shape, val1, dtype=dtype1, device='cpu')
                y2 = torch.tensor(val2, dtype=dtype2, device='cpu')
                self.assertEqual(getattr(x1, binop)(y1), getattr(x2, binop)(y2))
                x3 = torch.tensor(val1, dtype=dtype1, device='mps')
                y3 = torch.full(full_shape, val2, dtype=dtype2, device='mps')
                x4 = torch.tensor(val1, dtype=dtype1, device='cpu')
                y4 = torch.full(full_shape, val2, dtype=dtype2, device='cpu')
                self.assertEqual(getattr(x3, binop)(y3), getattr(x4, binop)(y4))
                self.assertEqual(
                    getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop)
                           (torch.full(full_shape, val2, dtype=dtype2, device='mps')),
                    getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                           (torch.full(full_shape, val2, dtype=dtype2, device='cpu')))

    def test_nansum(self):
        def helper(dtype, noncontiguous, dim):
            zero_cpu = torch.zeros((), dtype=dtype)

            # Randomly scale the values
            scale = random.randint(10, 100)
            x_cpu: torch.Tensor = make_tensor(
                (5, 5), dtype=dtype, device='cpu',
                low=-scale, high=scale, noncontiguous=noncontiguous)

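            # For floating dtypes, entries below 0.2 * scale are set to NaN
            # while a copy keeps them zeroed, so a plain sum() over the copy
            # equals nansum() over the NaN'd tensor.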
            if dtype.is_floating_point:
                nan_mask_cpu = x_cpu < (0.2 * scale)
                x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
                x_cpu[nan_mask_cpu] = np.nan
            else:
                x_no_nan_cpu = x_cpu

            x_mps = x_cpu.to('mps')
            actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
            expect_out_cpu = torch.empty(0, dtype=dtype)
            dim_kwargs = {"dim": dim} if dim is not None else {}
            expect = torch.sum(x_no_nan_cpu, **dim_kwargs)

            actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
            # Sanity check on CPU
            self.assertEqual(expect, actual_cpu)

            # Test MPS
            actual_mps = torch.nansum(x_mps, **dim_kwargs)
            # Test out= variant
            torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
            torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
            self.assertEqual(expect, actual_mps)
            self.assertEqual(expect_out_cpu, actual_out_mps)

        args = itertools.product(
            (torch.float16, torch.float32, torch.int32, torch.int64),   # dtype
            (True, False),                                              # noncontiguous
            (0, 1, None),                                               # dim
        )

        for dtype, noncontiguous, dim in args:
            with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
                helper(dtype, noncontiguous, dim)

    def test_cumsum_all_dtypes(self):
        def helper(dtype):
            t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
            t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")

            a = t.cumsum(0, dtype=dtype)
            a_cpu = t_cpu.cumsum(0, dtype=dtype)

            self.assertEqual(a.cpu(), a_cpu)
        [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]

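        # int64 cumsum needs macOS 13.3+; on older versions MPS raises the
        # error asserted below, so either outcome passes.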
        try:
            helper(torch.int64)
        except Exception as e:
            e_string = str(e)
            self.assertEqual(e_string, "MPS does not support cumsum_out_mps op with int64 input." +
                             " Support has been added in macOS 13.3")

    def test_cumsum_bool(self):
        a = torch.ones(2**16, dtype=torch.bool)
        t_cpu = a.cumsum(0)
        t_mps = a.to("mps").cumsum(0)

        self.assertEqual(t_cpu, t_mps)

    def test_cumsum_minus_one_axis(self):
        def helper(dtype):
            # Test with axis -1
            cpu_x = None
            if dtype == torch.float32:
                cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
            else:
                cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = cpu_x.cumsum(-1)
            y = x.cumsum(-1)

            self.assertEqual(y, cpu_y)

        [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

    def test_cumprod_all_dtypes(self):
        def helper(dtype):
            t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
            t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")

            a = t.cumprod(0, dtype=dtype)
            a_cpu = t_cpu.cumprod(0, dtype=dtype)

            self.assertEqual(a.cpu(), a_cpu)
        [helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]

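        # Same macOS 13.3 gating for int64 as cumsum above.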
        try:
            helper(torch.int64)
        except Exception as e:
            e_string = str(e)
            self.assertEqual(e_string, "MPS does not support cumprod_out_mps op with int64 input."
                             + " Support has been added in macOS 13.3")

    def test_cumprod_minus_one_axis(self):
        def helper(dtype):
            # Test with axis -1
            cpu_x = None
            if dtype == torch.float32:
                cpu_x = torch.randn(10, 3, device='cpu', dtype=torch.float32)
            else:
                cpu_x = torch.randint(0, 20, (10, 3), device='cpu', dtype=torch.float32)
            x = cpu_x.detach().clone().to('mps')

            cpu_y = cpu_x.cumprod(-1)
            y = x.cumprod(-1)

            self.assertEqual(y, cpu_y)

        [helper(dtype) for dtype in [torch.float32, torch.int16, torch.int32, torch.uint8]]

    def test_median_int16(self):
        def helper(shape, dtype):
            cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
            x = cpu_x.detach().clone().to('mps')

            median_result = torch.median(x)
            median_result_cpu = torch.median(cpu_x)
            self.assertEqual(median_result, median_result_cpu)

        helper((2, 8, 4, 5), torch.int16)

    def test_activation_checkpoint_does_not_error(self):
        from torch.utils.checkpoint import checkpoint

        for use_reentrant in (True, False):
            a = torch.tensor(1., device="mps", requires_grad=True)

            def fn(x):
                return x.sin().cos().exp()

            out = checkpoint(fn, a, use_reentrant=use_reentrant)
            out.backward()

    def test_as_strided(self):
        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
        values_1 = [[1.0, 1.0], [1.0, 1.0]]
        cpu_x = torch.tensor(values, device='cpu')
        ones1 = torch.tensor(values_1, device='mps')
        x = cpu_x.detach().clone().to('mps').requires_grad_()
        strided_cpu = torch.as_strided(cpu_x, (2, 2), (1, 2))
        strided_mps = torch.as_strided(x, (2, 2), (1, 2))
        self.assertEqual(strided_mps, strided_cpu)
        strided_cpu_out = strided_cpu + ones1.to('cpu')
        strided_mps_out = strided_mps + ones1
        self.assertEqual(strided_cpu_out, strided_mps_out)

        # test with storage offsets
        cpu_x = torch.rand(3, 3, device='cpu')
        mps_x = cpu_x.to('mps')
        strided_cpu1 = torch.as_strided(cpu_x, (2, 2), (1, 2), 0)
        strided_mps1 = torch.as_strided(mps_x, (2, 2), (1, 2), 0)
        strided_cpu2 = torch.as_strided(cpu_x, (2, 2), (1, 2), 1)
        strided_mps2 = torch.as_strided(mps_x, (2, 2), (1, 2), 1)
        strided_cpu_out = strided_cpu1 - strided_cpu2
        strided_mps_out = strided_mps1 - strided_mps2
        self.assertEqual(strided_cpu_out, strided_mps_out)

    def test_unfold(self):
        x = torch.arange(1., 8)
        x_mps = torch.arange(1., 8, device="mps")

        y = x.unfold(0, 2, 1)
        y_mps = x_mps.unfold(0, 2, 1)

        self.assertEqual(y, y_mps)

    def test_unfold_all_devices_and_dtypes(self):
        supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
        for dt in supported_dtypes:
            x = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
            self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)

    def test_unfold_scalars(self):
        x = torch.tensor(0.5, device="mps")
        # unfold on a 0-dimensional tensor should always return a 1-dimensional
        # tensor of shape [size] (i.e., the second parameter to unfold)

        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1))
        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2))
        self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1))

    def test_bincount_simple(self):
        input = torch.randint(0, 8, (5,), dtype=torch.int32, device="mps")
        input_cpu = input.to("cpu")
        weights = torch.linspace(0, 1, steps=5, device="mps", dtype=torch.float32)
        weights_cpu = weights.to("cpu")

        x = torch.bincount(input)
        x_cpu = torch.bincount(input_cpu)
        self.assertEqual(x, x_cpu)

        y = input.bincount(weights)
        y_cpu = input_cpu.bincount(weights_cpu)
        self.assertEqual(y, y_cpu)

    def test_bincount_reduction(self):
        device = "mps"
        # negative input throws
        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
            torch.bincount(torch.tensor([1, -1], device=device, dtype=torch.int32))
        # n-d input, with n > 1 throws
        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
            torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device))
        # minlength < 0 throws
        with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
            torch.bincount(torch.tensor([1, 3], device=device),
                           torch.tensor([.2, .2], device=device),
                           minlength=-1)
        # n-d weights, with n > 1 throws
        with self.assertRaisesRegex(RuntimeError, '1-d'):
            torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
                           torch.tensor([[1., 0.3], [1., 0.3]], device=device, dtype=torch.float))
        # input and weights dim mismatch
        with self.assertRaisesRegex(RuntimeError, 'same length'):
            torch.bincount(torch.tensor([1, 0], device=device, dtype=torch.int32),
                           torch.tensor([1., 0.3, 0.5], device=device, dtype=torch.float))
        # 1-d input with no elements and default minlength
        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)),
                         torch.zeros(0, dtype=torch.long, device=device))
        # 1-d input with no elements and specified minlength
        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10),
                         torch.zeros(10, dtype=torch.long, device=device))

        # test tensor method without weights
        long_counts = torch.tensor(
            [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount()
        self.assertEqual(
            torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
            long_counts)
        # test avoiding overflow for uint8 (#76979)
        count_uint8 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.uint8, device=device).bincount()
        count_int16 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.int16, device=device).bincount()
        self.assertEqual(count_uint8, count_int16)
        # test minlength functionality
        int_counts = torch.bincount(
            torch.tensor([1, 1, 1, 1], device=device, dtype=torch.int32), minlength=5)
        self.assertEqual(
            torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
            int_counts)
        # test weights
        byte_counts = torch.bincount(
            torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32),
            torch.tensor([.1, .2, .3, .4, .5], device=device))
        self.assertEqual(
            torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts)
        byte_counts = torch.bincount(
            torch.tensor([0, 1, 1, 1, 4], device=device, dtype=torch.int32),
            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
        self.assertEqual(
            torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.int32), byte_counts)
        # test non-contiguous inputs and weights
        inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device, dtype=torch.int32)
        weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
        for i in [0, 1]:
            assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
            assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
        # inputs are non-contiguous but weights are contiguous
        self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
        # inputs and weights are non-contiguous
        self.assertEqual(
            inputs[:, 1].bincount(weights[:, 1]),
            torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
        # weights are non-contiguous but inputs are contiguous
        self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
                         torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))

        # test bincount on non-contiguous slices
        all0s = torch.zeros((32, 2), dtype=torch.int32, device=device)
        self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))

        all1s = torch.ones((32, 2), dtype=torch.int32, device=device)
        self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))

        # test large number of bins - global memory use
        big_exp = torch.zeros(100, device=device)
        big_exp[-1] = 50.0
        big_w = torch.tensor([.5] * 100, device=device)
        big_out = torch.tensor([99] * 100, device=device, dtype=torch.int32).bincount(big_w)
        self.assertEqual(big_exp, big_out)
        # test large input size
        big_exp = torch.zeros(2, device=device, dtype=torch.int64)
        big_exp[1] = 10
        big_out = torch.ones(10, dtype=torch.int8, device=device).bincount()
        self.assertEqual(big_exp, big_out)

    def test_bincount(self):
        device = "mps"
        input_size = (5000,)
        w = torch.randn(input_size, dtype=torch.float, device=device)
        w_cpu = w.cpu()

        t = torch.randint(50, input_size, dtype=torch.int8, device=device)
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        t = torch.randint(500, input_size, dtype=torch.int32, device=device)
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        t = torch.randint(2000, input_size, dtype=torch.int32, device=device)
        self.assertEqual(t.cpu().bincount(), t.bincount())
        self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))

        t = torch.zeros([10], dtype=torch.int32, device=device)
        t[0] = 35488
        counted = t.bincount(minlength=65536)
        self.assertEqual(torch.sum(counted), 10)

    def test_sum_backward(self):
        def helper(n, c):
            values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
            cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            all_sum = torch.sum(x)
            all_sum_cpu = torch.sum(cpu_x)

            all_sum.backward()
            all_sum_cpu.backward()
            self.assertEqual(all_sum, all_sum_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        helper(3, 3)

    # L1 loss
    def test_l1_loss(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.L1Loss(reduction=reduction)

            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # chose 2 just to make the grad_output > 1 in backward pass
                outputCPU.backward(gradient=torch.full_like(outputCPU, 2))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 2))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify that changes in shape do not cause cached graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')

    # Mean Squared Error
    def test_mse_loss(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.MSELoss(reduction=reduction)

            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # chose 2 just to make the grad_output > 1 in backward pass
                outputCPU.backward(gradient=torch.full_like(outputCPU, 2))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 2))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify that changes in shape do not cause cached graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')

    def test_mse_loss_strided_output(self):
        # https://github.com/pytorch/pytorch/issues/124621
        lf = nn.MSELoss(reduction='none')
        model_cpu = nn.Sequential(
            nn.Conv1d(3, 3, 1),
        )
        model_mps = copy.deepcopy(model_cpu).to("mps")

        x = torch.randn(128, 10, 3)
        x = x.permute(0, 2, 1)

        x_mps = x.detach().clone().to("mps").permute(0, 2, 1)
        x_mps = x_mps.permute(0, 2, 1)

        y = model_cpu(x)
        y_mps = model_mps(x_mps)

        y = y.permute(0, 2, 1)[:, :5, :]
        y_mps = y_mps.permute(0, 2, 1)[:, :5, :]

        y_hat = torch.randn(128, 5, 3)
        y_hat_mps = y_hat.detach().clone().to("mps")

        loss = lf(y, y_hat)
        loss_mps = lf(y_mps, y_hat_mps)
        self.assertEqual(loss, loss_mps)

    # Binary Cross Entropy
    def test_bce_loss_simple(self):
        def helper(shape, reduction):
            # create the criterion
            loss = torch.nn.BCELoss(reduction=reduction)

            # input and target must be within [0..1]
            input_t = np.random.random_sample(size=shape).astype(np.float32)
            target_t = np.random.random_sample(size=shape).astype(np.float32)
            inputCPU = torch.tensor(input_t, device='cpu', dtype=torch.float, requires_grad=True)
            targetCPU = torch.tensor(target_t, device='cpu', dtype=torch.float, requires_grad=False)
            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
            targetMPS = targetCPU.detach().clone().to('mps')

            # forward pass
            outputCPU = loss(inputCPU, targetCPU)
            outputMPS = loss(inputMPS, targetMPS)
            self.assertEqual(outputCPU, outputMPS)

            # backward pass
            if reduction != 'none':
                # chose 0.6 just to have the grad_output != 1
                outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
                outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
                self.assertEqual(inputCPU.grad, inputMPS.grad)

        helper([8, 5, 4], 'none')
        helper([7, 5, 2, 4], 'sum')
        # verify that changes in shape do not cause cached graph lookup problems
        helper([7, 5, 2, 4, 6], 'sum')
        helper([8, 4, 5, 7, 6], 'mean')
        helper([1, 1, 32, 32], 'mean')

    def test_bce_loss_always_nonnegative(self):
        target = torch.ones(5, device='mps')
        input = torch.ones(5, device='mps')
        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

        target = torch.zeros(5, device='mps')
        input = torch.zeros(5, device='mps')
        self.assertEqual((nn.BCELoss()(input, target) < 0).sum(), 0)

    def test_bce_loss_size_mismatch(self):
        bceloss = nn.BCELoss()
        a = torch.rand(25, device='mps')
        b = torch.rand(25, 1, device='mps')
        with self.assertRaisesRegex(ValueError, r'Using a target size \('):
            bceloss(a, b)

    def test_bce_with_logits_gives_same_result_as_sigmoid_and_bce_loss_large_tensors_with_grad(self):
4869        x_size = 1024
4870        y_size = 256
4871        target = torch.rand(x_size, y_size, device='mps')
4872
4873        for reduction in ['none', 'mean', 'sum']:
4874            output_sig = torch.rand(x_size, y_size, device='mps') - 0.5
4875            output_logits = output_sig.clone().detach()
4876
4877            output_sig.requires_grad = True
4878            output_logits.requires_grad = True
4879            weight = torch.rand(y_size, device='mps')
4880
4881            loss_sig = nn.BCELoss(weight, reduction=reduction)(
4882                torch.sigmoid(output_sig), target
4883            )
4884            loss_logits = nn.BCEWithLogitsLoss(weight, reduction=reduction)(
4885                output_logits, target
4886            )
4887
4888            self.assertEqual(loss_logits, loss_sig)
4889
4890            if reduction == 'none':
4891                grad = torch.rand(x_size, y_size, device='mps')
4892                loss_sig.backward(grad)
4893                loss_logits.backward(grad)
4894            else:
4895                loss_sig.backward()
4896                loss_logits.backward()
4897
4898            self.assertEqual(output_sig.grad, output_logits.grad)
4899
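    # [Editorial sketch, not in the original suite] The equivalence exercised above
    # rests on BCEWithLogitsLoss(x, t) == BCELoss(sigmoid(x), t); the fused op just
    # evaluates the numerically stable rearrangement
    # max(x, 0) - x*t + log1p(exp(-|x|)). A CPU check of that identity
    # (the `_sketch_` name is hypothetical):
    def _sketch_bce_with_logits_stable_form(self):
        x = torch.randn(8, 5) * 3
        t = torch.rand(8, 5)
        stable = x.clamp(min=0) - x * t + torch.log1p(torch.exp(-x.abs()))
        self.assertEqual(nn.BCEWithLogitsLoss(reduction='none')(x, t), stable)
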
4900    def test_bce_with_logits_has_correct_grad_at_zero(self):
4901        output = torch.zeros(3, 1, requires_grad=True, device='mps')
4902        target = torch.zeros(3, 1, device='mps')
4903        nn.BCEWithLogitsLoss(reduction='sum')(output, target).backward()
4904        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
4905        self.assertEqual(output.grad, expected_grad)
4906
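    # [Editorial sketch, not in the original suite] Where the 0.5 above comes from:
    # the gradient of BCEWithLogits w.r.t. the logit is sigmoid(x) - t, which is
    # sigmoid(0) - 0 = 0.5 at zero. A CPU spot check (hypothetical `_sketch_` name):
    def _sketch_bce_with_logits_grad_is_sigmoid_minus_target(self):
        x = torch.randn(6, requires_grad=True)
        t = torch.rand(6)
        nn.BCEWithLogitsLoss(reduction='sum')(x, t).backward()
        self.assertEqual(x.grad, torch.sigmoid(x.detach()) - t)
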
4907    def test_bce_with_logits_broadcasts_weights(self):
4908        target = torch.rand(16, 4, device='mps')
4909        output = torch.rand(16, 4, device='mps') - 0.5
4910
4911        weight = torch.rand(4, device='mps')
4912        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
4913
4914        weight = weight.expand(16, 4).contiguous()
4915        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
4916
4917        self.assertEqual(out1, out2)
4918
4919        weight = torch.rand(16, 1, device='mps')
4920        out1 = nn.BCEWithLogitsLoss(weight)(output, target)
4921
4922        weight = weight.expand(16, 4).contiguous()
4923        out2 = nn.BCEWithLogitsLoss(weight)(output, target)
4924
4925        self.assertEqual(out1, out2)
4926
4927    def test_bce_with_logits_ones_in_pos_weights_are_the_same_as_none(self):
4928        target = torch.rand(64, 4, device='mps')
4929        output = torch.rand(64, 4, device='mps') - 0.5
4930        pos_weight = torch.ones(64, 4, device='mps')
4931
4932        self.assertEqual(nn.BCEWithLogitsLoss()(output, target),
4933                         nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target))
4934
4935    def test_bce_with_logits_broadcasts_pos_weights(self):
4936        target = torch.rand(64, 4, device='mps')
4937        output = torch.rand(64, 4, device='mps') - 0.5
4938        pos_weight = torch.rand(4, device='mps')
4939        out1 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
4940
4941        pos_weight1 = pos_weight.expand(1, 4)
4942        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight1)(output, target)
4943
4944        pos_weight2 = pos_weight.expand(64, 4)
4945        out3 = nn.BCEWithLogitsLoss(pos_weight=pos_weight2)(output, target)
4946
4947        self.assertEqual(out1, out2)
4948        self.assertEqual(out1, out3)
4949
4950    def test_bce_with_logits_with_pos_weight_has_correct_grad_at_zero(self):
4951        output = torch.zeros(3, 1, requires_grad=True, device='mps')
4952        target = torch.zeros(3, 1, device='mps')
4953        pos_weight = torch.ones(3, 1, device='mps')
4954        nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='sum')(output, target).backward()
4955        expected_grad = torch.empty(3, 1, device='mps').fill_(0.5)
4956        grad = output.grad
4957        self.assertEqual(grad, expected_grad)
4958
4959    def test_bce_with_logits_stability(self):
4960        output = torch.tensor([0., -120.], device='mps')
4961        target = torch.tensor([0., 1.], device='mps')
4962        pos_weight = torch.tensor([1., 1.], device='mps')
4963
4964        out1 = nn.BCEWithLogitsLoss()(output, target)
4965        self.assertTrue(torch.isfinite(out1).all().item())
4966
4967        out2 = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(output, target)
4968        self.assertTrue(torch.isfinite(out2).all().item())
4969
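    # [Editorial sketch, not in the original suite] What the stability test above
    # guards against: for large-magnitude logits, sigmoid underflows to 0 in float32
    # and the naive -t*log(sigmoid(x)) formula blows up, while the fused op evaluates
    # the stable form and stays finite (hypothetical `_sketch_` name):
    def _sketch_fused_bce_logits_avoids_overflow(self):
        x = torch.tensor([-120.])
        t = torch.tensor([1.])
        naive = -t * torch.log(torch.sigmoid(x))  # log(0) -> -inf, loss -> inf
        self.assertFalse(torch.isfinite(naive).all().item())
        fused = nn.BCEWithLogitsLoss()(x, t)      # stable form evaluates to ~120
        self.assertTrue(torch.isfinite(fused).all().item())
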
4970    def test_bce_loss_broadcasts_weights(self):
4971        sigmoid = nn.Sigmoid()
4972        target = torch.rand(16, 4, device='mps')
4973        output = torch.rand(16, 4, device='mps') - 0.5
4974
4975        weight = torch.rand(4, device='mps')
4976        out1 = nn.BCELoss(weight)(sigmoid(output), target)
4977
4978        weight = weight.expand(16, 4).contiguous()
4979        out2 = nn.BCELoss(weight)(sigmoid(output), target)
4980
4981        self.assertEqual(out1, out2)
4982
4983        weight = torch.rand(16, 1, device='mps')
4984        out1 = nn.BCELoss(weight)(sigmoid(output), target)
4985
4986        weight = weight.expand(16, 4).contiguous()
4987        out2 = nn.BCELoss(weight)(sigmoid(output), target)
4988
4989        self.assertEqual(out1, out2)
4990
4991    def test_cross_entropy_loss(self):
4992        # Regression test for https://github.com/pytorch/pytorch/issues/116095
4993        loss = nn.CrossEntropyLoss()
4994        pred = torch.randn(3, 5, requires_grad=True, dtype=torch.float16, device='mps')
4995        target = torch.ones(3, dtype=torch.long, device='mps')
4996        output = loss(pred, target)
4997        output.backward()
4998
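    # [Editorial sketch, not in the original suite] Cross entropy with class indices
    # decomposes as nll_loss(log_softmax(pred)); the regression above exercises the
    # fused fp16 MPS path, and the CPU check below records the decomposition it is
    # expected to agree with (hypothetical `_sketch_` name):
    def _sketch_cross_entropy_decomposition(self):
        pred = torch.randn(3, 5)
        target = torch.tensor([1, 0, 4])
        self.assertEqual(F.cross_entropy(pred, target),
                         F.nll_loss(F.log_softmax(pred, dim=1), target))
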
4999    def test_log_softmax(self):
5000        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
5001        cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
5002        mps_x = torch.tensor(values, device='mps', requires_grad=True)
5003
5004        cpu_log_softmax = F.log_softmax(cpu_x, dim=0)
5005        mps_log_softmax = F.log_softmax(mps_x, dim=0)
5006        self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))
5007
5008        cpu_grad = torch.ones_like(cpu_log_softmax)
5009        mps_grad = torch.ones_like(cpu_log_softmax).to('mps')
5010
5011        cpu_log_softmax.backward(gradient=cpu_grad)
5012        mps_log_softmax.backward(gradient=mps_grad)
5013
5014        self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
5015
5016    def test_log_softmax_large_numbers(self):
5017        values = [
5018            [10.0, 100.0, 1000.0, 10000.0, 100000.0, 1000000.0],
5019            [-10.0, -100.0, -1000.0, -10000.0, -100000.0, -1000000.0]
5020        ]
5021        cpu_x = torch.tensor(values, device='cpu', requires_grad=True)
5022        mps_x = torch.tensor(values, device='mps', requires_grad=True)
5023
5024        cpu_log_softmax = F.log_softmax(cpu_x, dim=-1)
5025        mps_log_softmax = F.log_softmax(mps_x, dim=-1)
5026        self.assertEqual(cpu_log_softmax, mps_log_softmax.to('cpu'))
5027
5028        cpu_grad = torch.ones_like(cpu_log_softmax)
5029        mps_grad = torch.ones_like(cpu_log_softmax).to('mps')
5030
5031        cpu_log_softmax.backward(gradient=cpu_grad)
5032        mps_log_softmax.backward(gradient=mps_grad)
5033
5034        self.assertEqual(cpu_x.grad, mps_x.grad.to('cpu'))
5035
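    # [Editorial sketch, not in the original suite] The large-value cases above work
    # because log_softmax is shift invariant: subtracting max(x) leaves the result
    # unchanged while keeping exp() in range (hypothetical `_sketch_` name):
    def _sketch_log_softmax_shift_invariance(self):
        x = torch.tensor([10.0, 100.0, 1000.0])
        self.assertEqual(F.log_softmax(x, dim=0),
                         F.log_softmax(x - x.max(), dim=0))
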
5036    def test_eq(self):
5037        values1 = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
5038        values2 = [[[1.0, 2.0, 15.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [0.0, 11.0, 12.0]]]
5039        mps_x = torch.tensor(values1, device='mps')
5040        mps_y = torch.tensor(values2, device='mps')
5041        cpu_x = torch.tensor(values1, device='cpu')
5042        cpu_y = torch.tensor(values2, device='cpu')
5043        result_mps = torch.eq(mps_x, mps_y)
5044        result_cpu = torch.eq(cpu_x, cpu_y)
5045
5046        self.assertEqual(result_cpu, result_mps.to('cpu'))
5047
5048    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
5049    def test_signed_vs_unsigned_comparison(self):
5050        cpu_x = torch.tensor((-1, 2, 3), device='cpu', dtype=torch.uint8)
5051        mps_x = torch.tensor((-1, 2, 3), device='mps', dtype=torch.uint8)
5052        # when comparing signed with unsigned, the signed value should always be cast to unsigned
5053        self.assertEqual(cpu_x == -1, mps_x == -1)
5054        self.assertEqual(cpu_x > -1, mps_x > -1)
5055        self.assertEqual(cpu_x < -1, mps_x < -1)
5056
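    # [Editorial sketch, not in the original suite] The cast the comment above refers
    # to, made concrete on CPU: -1 wraps to 255 when stored as uint8, so the scalar
    # comparisons really compare against 255 (hypothetical `_sketch_` name):
    def _sketch_uint8_wraparound(self):
        x = torch.tensor((-1, 2, 3), dtype=torch.uint8)
        self.assertEqual(x[0].item(), 255)  # -1 wrapped at construction time
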
5057    def test_eq_int64(self):
5058        values1 = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
5059        values2 = [[[1, 2, 15], [4, 5, 6]], [[7, 8, 9], [0, 11, 12]]]
5060        mps_x = torch.tensor(values1, device='mps')
5061        mps_y = torch.tensor(values2, device='mps')
5062        cpu_x = torch.tensor(values1, device='cpu')
5063        cpu_y = torch.tensor(values2, device='cpu')
5064        result_mps = torch.eq(mps_x, mps_y)
5065        result_cpu = torch.eq(cpu_x, cpu_y)
5066
5067        self.assertEqual(result_cpu, result_mps.to('cpu'))
5068
5069    def test_ne(self):
5070        def helper(shape):
5071            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5072            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5073            mps_x = cpu_x.detach().clone().to('mps')
5074            mps_y = cpu_y.detach().clone().to('mps')
5075            result_mps = torch.ne(mps_x, mps_y)
5076            result_cpu = torch.ne(cpu_x, cpu_y)
5077
5078            self.assertEqual(result_cpu, result_mps.to('cpu'))
5079
5080        helper((2, 3, 4, 5))
5081
5082    def test_ne_scalar(self):
5083        def helper(shape):
5084            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5085            mps_x = cpu_x.detach().clone().to('mps')
5086            result_mps = torch.ne(mps_x, 0.0)
5087            result_cpu = torch.ne(cpu_x, 0.0)
5088
5089            self.assertEqual(result_cpu, result_mps.to('cpu'))
5090
5091        helper((2, 3, 4, 5))
5092
5093    def test_lt(self):
5094        def helper(shape):
5095            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5096            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5097            mps_x = cpu_x.detach().clone().to('mps')
5098            mps_y = cpu_y.detach().clone().to('mps')
5099            result_mps = torch.lt(mps_x, mps_y)
5100            result_cpu = torch.lt(cpu_x, cpu_y)
5101
5102            self.assertEqual(result_cpu, result_mps.to('cpu'))
5103
5104        helper((2, 3, 4, 5))
5105
5106    def test_lt_scalar(self):
5107        def helper(shape):
5108            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5109            mps_x = cpu_x.detach().clone().to('mps')
5110            result_mps = torch.lt(mps_x, 0.0)
5111            result_cpu = torch.lt(cpu_x, 0.0)
5112
5113            self.assertEqual(result_cpu, result_mps.to('cpu'))
5114
5115        helper((2, 3, 4, 5))
5116
5117    def test_le(self):
5118        def helper(shape):
5119            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5120            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5121            mps_x = cpu_x.detach().clone().to('mps')
5122            mps_y = cpu_y.detach().clone().to('mps')
5123            result_mps = torch.le(mps_x, mps_y)
5124            result_cpu = torch.le(cpu_x, cpu_y)
5125
5126            self.assertEqual(result_cpu, result_mps.to('cpu'))
5127
5128        helper((2, 3, 4, 5))
5129
5130    def test_le_scalar(self):
5131        def helper(shape):
5132            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5133            mps_x = cpu_x.detach().clone().to('mps')
5134            result_mps = torch.le(mps_x, 0.0)
5135            result_cpu = torch.le(cpu_x, 0.0)
5136
5137            self.assertEqual(result_cpu, result_mps.to('cpu'))
5138
5139        helper((2, 3, 4, 5))
5140
5141    def test_ge(self):
5142        def helper(shape):
5143            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5144            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5145            mps_x = cpu_x.detach().clone().to('mps')
5146            mps_y = cpu_y.detach().clone().to('mps')
5147            result_mps = torch.ge(mps_x, mps_y)
5148            result_cpu = torch.ge(cpu_x, cpu_y)
5149
5150            self.assertEqual(result_cpu, result_mps.to('cpu'))
5151
5152        helper((2, 3, 4, 5))
5153
5154    def test_ge_scalar(self):
5155        def helper(shape):
5156            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5157            mps_x = cpu_x.detach().clone().to('mps')
5158            result_mps = torch.ge(mps_x, 0.0)
5159            result_cpu = torch.ge(cpu_x, 0.0)
5160
5161            self.assertEqual(result_cpu, result_mps.to('cpu'))
5162
5163        helper((2, 3, 4, 5))
5164
5165    def test_gt(self):
5166        def helper(shape):
5167            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5168            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float)
5169            mps_x = cpu_x.detach().clone().to('mps')
5170            mps_y = cpu_y.detach().clone().to('mps')
5171            result_mps = torch.gt(mps_x, mps_y)
5172            result_cpu = torch.gt(cpu_x, cpu_y)
5173
5174            self.assertEqual(result_cpu, result_mps.to('cpu'))
5175
5176        helper((2, 3, 4, 5))
5177
5178    def test_gt_scalar(self):
5179        def helper(shape):
5180            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
5181            mps_x = cpu_x.detach().clone().to('mps')
5182            result_mps = torch.gt(mps_x, 0.0)
5183            result_cpu = torch.gt(cpu_x, 0.0)
5184
5185            self.assertEqual(result_cpu, result_mps.to('cpu'))
5186
5187        helper((2, 3, 4, 5))
5188
5189    def test_argmax(self):
5190        # https://github.com/pytorch/pytorch/issues/98191
5191        cpu_tensor = torch.tensor([[0, 1], [2, 1], [1, 0]])
5192        res_cpu = torch.argmax(cpu_tensor, dim=1)
5193
5194        mps_tensor = cpu_tensor.to(torch.device('mps'))
5195        res_mps = torch.argmax(mps_tensor, dim=1)
5196        self.assertEqual(res_cpu, res_mps)
5197
5198        # https://github.com/pytorch/pytorch/issues/92311
5199        mps_tensor = torch.randn(10, 2, device='mps', dtype=torch.float32)
5200        cpu_tensor = mps_tensor.detach().clone().cpu()
5201
5202        res_mps = torch.argmax(mps_tensor, dim=1)
5203        res_cpu = torch.argmax(cpu_tensor, dim=1)
5204        self.assertEqual(res_cpu, res_mps)
5205
5206    # Test forward argmin argmax
5207    def test_argmin_argmax(self):
5208        def helper(n, c, h, w, reduction_type, dtype=torch.float32):
5209            if reduction_type == "max":
5210                arg_reduction_fn = torch.argmax
5211            else:
5212                arg_reduction_fn = torch.argmin
5213
5214            cpu_x = None
5215            x = None
5216            if (dtype not in [torch.float32, torch.bool]):
5217                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5218                x = cpu_x.detach().clone().to('mps')
5219            elif (dtype == torch.bool):
5220                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5221                x = cpu_x.detach().clone().to('mps')
5222            else:
5223                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
5224                x = cpu_x.detach().clone().to('mps').requires_grad_()
5225
5226            y = arg_reduction_fn(x)
5227            ref_y = arg_reduction_fn(cpu_x)
5228            self.assertEqual(y, ref_y)
5229
5230            y_0 = arg_reduction_fn(x, dim=0)
5231            refy_0 = arg_reduction_fn(cpu_x, dim=0)
5232            self.assertEqual(y_0, refy_0)
5233
5234            y_0dim = arg_reduction_fn(x, dim=0, keepdim=True)
5235            refy_0dim = arg_reduction_fn(cpu_x, dim=0, keepdim=True)
5236            self.assertEqual(y_0dim, refy_0dim)
5237
5238            y_1 = arg_reduction_fn(x, dim=1)
5239            refy_1 = arg_reduction_fn(cpu_x, dim=1)
5240            self.assertEqual(y_1, refy_1)
5241
5242            y_1dim = arg_reduction_fn(x, dim=1, keepdim=True)
5243            refy_1dim = arg_reduction_fn(cpu_x, dim=1, keepdim=True)
5244            self.assertEqual(y_1dim, refy_1dim)
5245
5246            y_2 = arg_reduction_fn(x, dim=2)
5247            refy_2 = arg_reduction_fn(cpu_x, dim=2)
5248            self.assertEqual(y_2, refy_2)
5249
5250            y_2dim = arg_reduction_fn(x, dim=2, keepdim=True)
5251            refy_2dim = arg_reduction_fn(cpu_x, dim=2, keepdim=True)
5252            self.assertEqual(y_2dim, refy_2dim)
5253
5254            y_3 = arg_reduction_fn(x, dim=3)
5255            refy_3 = arg_reduction_fn(cpu_x, dim=3)
5256            self.assertEqual(y_3, refy_3)
5257
5258            y_3dim = arg_reduction_fn(x, dim=3, keepdim=True)
5259            refy_3dim = arg_reduction_fn(cpu_x, dim=3, keepdim=True)
5260            self.assertEqual(y_3dim, refy_3dim)
5261
5262        helper(2, 8, 4, 4, "max", torch.float32)
5263        helper(2, 8, 4, 4, "max", torch.int32)
5264        helper(2, 8, 4, 4, "max", torch.float16)
5265        helper(2, 8, 4, 4, "max", torch.int64)
5266        helper(2, 8, 4, 4, "min", torch.float32)
5267        helper(2, 8, 4, 4, "min", torch.int32)
5268        helper(2, 8, 4, 4, "min", torch.float16)
5269        helper(2, 8, 4, 4, "min", torch.int64)
5270
5271    @unittest.skipIf(product_version < 13.3, "Long data type supported from macOS 13.3 and above")
5272    def test_reduction_sum_max_long_val(self):
5273        x_mps = torch.tensor([sys.maxsize, sys.maxsize - 10, sys.maxsize - 5, sys.maxsize - 18], device="mps")
5274        x_cpu = x_mps.detach().clone().cpu()
5275
5276        res_mps = torch.sum(x_mps)
5277        res_cpu = torch.sum(x_cpu)
5278        self.assertEqual(res_mps, res_cpu)
5279
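    # [Editorial sketch, not in the original suite] The sum above overflows int64
    # (four addends near sys.maxsize), so the test is really checking that MPS wraps
    # the same way the CPU does. Assuming two's-complement wraparound, the result can
    # be predicted with Python's arbitrary-precision ints (hypothetical name):
    def _sketch_int64_wraparound_prediction(self):
        vals = [sys.maxsize, sys.maxsize - 10, sys.maxsize - 5, sys.maxsize - 18]
        # fold the exact sum back into the signed 64-bit range
        wrapped = (sum(vals) + 2**63) % 2**64 - 2**63
        self.assertEqual(torch.tensor(vals).sum().item(), wrapped)
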
5280    # Test forward max
5281    # Note: gradients are not tested for now
5282    def test_max_el(self):
5283        def helper(n, c, h, w, dtype=torch.float32):
5284
5285            if (dtype not in [torch.float32, torch.bool]):
5286                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5287                x = cpu_x.detach().clone().to('mps')
5288            elif (dtype == torch.bool):
5289                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5290                x = cpu_x.detach().clone().to('mps')
5291            else:
5292                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
5293                x = cpu_x.detach().clone().to('mps')
5294
5295            ref_y = torch.max(cpu_x)
5296            y = torch.max(x)
5297            self.assertEqual(y, ref_y)
5298
5299            for dim in [0, 1, 2, 3]:
5300                for keepdim in [True, False]:
5301                    y, idx = torch.max(x, dim=dim, keepdim=keepdim)
5302                    refy, refidx = torch.max(cpu_x, dim=dim, keepdim=keepdim)
5303                    self.assertEqual(y, refy)
5304                    self.assertEqual(idx, refidx)
5305
5306            y_0 = torch.ones(c, h, w, device='mps', dtype=dtype)
5307            idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
5308            torch.max(x, dim=0, out=(y_0, idx_0))
5309            refy_0, refidx_0 = torch.max(cpu_x, dim=0)
5310            self.assertEqual(y_0, refy_0)
5311            self.assertEqual(idx_0, refidx_0)
5312
5313            y_0dim = torch.ones(1, c, h, w, device='mps', dtype=dtype)
5314            idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
5315            torch.max(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
5316            refy_0dim, refidx_0dim = torch.max(cpu_x, dim=0, keepdim=True)
5317            self.assertEqual(y_0dim, refy_0dim)
5318            self.assertEqual(idx_0dim, refidx_0dim)
5319
5320            y_1 = torch.ones(n, h, w, device='mps', dtype=dtype)
5321            idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
5322            torch.max(x, dim=1, out=(y_1, idx_1))
5323            refy_1, refidx_1 = torch.max(cpu_x, dim=1)
5324            self.assertEqual(y_1, refy_1)
5325            self.assertEqual(idx_1, refidx_1)
5326
5327            y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=dtype)
5328            idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
5329            torch.max(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
5330            refy_1dim, refidx_1dim = torch.max(cpu_x, dim=1, keepdim=True)
5331            self.assertEqual(y_1dim, refy_1dim)
5332            self.assertEqual(idx_1dim, refidx_1dim)
5333
5334            y_2 = torch.ones(n, c, w, device='mps', dtype=dtype)
5335            idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
5336            torch.max(x, dim=2, out=(y_2, idx_2))
5337            refy_2, refidx_2 = torch.max(cpu_x, dim=2)
5338            self.assertEqual(y_2, refy_2)
5339            self.assertEqual(idx_2, refidx_2)
5340
5341            y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=dtype)
5342            idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
5343            torch.max(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
5344            refy_2dim, refidx_2dim = torch.max(cpu_x, dim=2, keepdim=True)
5345            self.assertEqual(y_2dim, refy_2dim)
5346            self.assertEqual(idx_2dim, refidx_2dim)
5347
5348            y_3 = torch.ones(n, c, h, device='mps', dtype=dtype)
5349            idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
5350            torch.max(x, dim=3, out=(y_3, idx_3))
5351            refy_3, refidx_3 = torch.max(cpu_x, dim=3)
5352            self.assertEqual(y_3, refy_3)
5353            self.assertEqual(idx_3, refidx_3)
5354
5355            y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=dtype)
5356            idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
5357            torch.max(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
5358            refy_3dim, refidx_3dim = torch.max(cpu_x, dim=3, keepdim=True)
5359            self.assertEqual(y_3dim, refy_3dim)
5360            self.assertEqual(idx_3dim, refidx_3dim)
5361
5362        helper(2, 8, 4, 5, torch.float32)
5363        helper(2, 8, 4, 5, torch.int32)
5364        # helper(2, 8, 4, 5, torch.int64)
5365
5366    def test_median(self):
5367        def helper_dtype_int32(n1, n2, n3):
5368            cpu_x = torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32)
5369            mps_x = cpu_x.detach().clone().to('mps')
5370
5371            result_cpu = torch.median(cpu_x)
5372            result_mps = torch.median(mps_x)
5373
5374            self.assertEqual(result_cpu, result_mps)
5375
5376            for dim in [0, 1, 2]:
5377                for keepdim in [True, False]:
5378                    y, idx = torch.median(mps_x, dim=dim, keepdim=keepdim)
5379                    refy, refidx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
5380                    self.assertEqual(y, refy)
5381                    self.assertEqual(idx, refidx)
5382
5383        def helper_dtype_float32(n1, n2, n3):
5384            cpu_x = torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32)
5385            mps_x = cpu_x.detach().clone().to('mps')
5386
5387            result_cpu = torch.median(cpu_x)
5388            result_mps = torch.median(mps_x)
5389
5390            self.assertEqual(result_cpu, result_mps)
5391
5392            for dim in [0, 1, 2]:
5393                for keepdim in [True, False]:
5394                    y, idx = torch.median(mps_x, dim=dim, keepdim=keepdim)
5395                    refy, refidx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
5396                    self.assertEqual(y, refy)
5397                    self.assertEqual(idx, refidx)
5398
5399        helper_dtype_int32(10, 10, 10)  # even number of elements
5400        helper_dtype_int32(3, 3, 3)  # odd number of elements
5401        helper_dtype_int32(1, 1, 1)
5402        helper_dtype_int32(1, 2, 3)
5403        helper_dtype_float32(10, 10, 10)
5404        helper_dtype_float32(3, 3, 3)
5405        helper_dtype_float32(1, 1, 1)
5406
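    # [Editorial sketch, not in the original suite] On the "even number of elements"
    # cases above: torch.median returns the lower of the two middle values rather
    # than their mean, so it can differ from np.median (hypothetical `_sketch_` name):
    def _sketch_median_returns_lower_middle(self):
        x = torch.tensor([1., 2., 3., 4.])
        self.assertEqual(torch.median(x).item(), 2.0)  # lower middle value
        self.assertEqual(np.median(x.numpy()), 2.5)    # numpy averages the middles
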
5407    def test_any(self):
5408        def helper(shape):
5409            input_xs = []
5410            prod = 1
5411
5412            for i in range(len(shape)):
5413                prod *= shape[i]
5414            input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
5415            input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
5416            input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
5417            input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
5418            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
5419            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
5420            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
5421            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
5422            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
5423            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())
5424
5425            for i, cpu_x in enumerate(input_xs):
5426                x = cpu_x.detach().clone().to('mps')
5427                y = torch.any(x)
5428                ref_y = torch.any(cpu_x)
5429                self.assertEqual(y, ref_y)
5430
5431                y_0 = torch.any(x, dim=0)
5432                refy_0 = torch.any(cpu_x, dim=0)
5433                self.assertEqual(y_0, refy_0)
5434
5435                y_0dim = torch.any(x, dim=0, keepdim=True)
5436                refy_0dim = torch.any(cpu_x, dim=0, keepdim=True)
5437                self.assertEqual(y_0dim, refy_0dim)
5438
5443                y_1 = torch.any(x, dim=1)
5444                refy_1 = torch.any(cpu_x, dim=1)
5445                self.assertEqual(y_1, refy_1)
5446
5447                y_1dim = torch.any(x, dim=1, keepdim=True)
5448                refy_1dim = torch.any(cpu_x, dim=1, keepdim=True)
5449                self.assertEqual(y_1dim, refy_1dim)
5450
5451                if (len(shape) > 2):
5452                    y_2 = torch.any(x, dim=2)
5453                    refy_2 = torch.any(cpu_x, dim=2)
5454                    self.assertEqual(y_2, refy_2)
5455
5456                    y_2dim = torch.any(x, dim=2, keepdim=True)
5457                    refy_2dim = torch.any(cpu_x, dim=2, keepdim=True)
5458                    self.assertEqual(y_2dim, refy_2dim)
5459
5460                    y_3 = torch.any(x, dim=3)
5461                    refy_3 = torch.any(cpu_x, dim=3)
5462                    self.assertEqual(y_3, refy_3)
5463
5464                    y_3dim = torch.any(x, dim=3, keepdim=True)
5465                    refy_3dim = torch.any(cpu_x, dim=3, keepdim=True)
5466                    self.assertEqual(y_3dim, refy_3dim)
5467        helper((1, 1, 1, 1))
5468        helper((1, 1, 3, 3))
5469        helper((7, 13))
5470        helper((2, 8, 4, 5))
5471
5472    def test_reduction_ops_5D(self):
5473        def helper(fn, dim):
5474            shape = (1, 1, 2, 1, 1)
5475            x_cpu = fn(torch.zeros(shape), dim=dim)
5476            x_mps = fn(torch.zeros(shape, device="mps"), dim=dim)
5477            self.assertEqual(x_cpu, x_mps.cpu())
5478        for fn in [torch.any, torch.all]:
5479            for dim in range(0, 4):
5480                helper(fn, dim)
5481
5482        # 6D tensor reductions
5483        # Regression test for https://github.com/pytorch/pytorch/issues/95538
5484        x = (torch.rand(2, 3, 4, 3, 4, 2, device="mps") - .5).relu()
5485        self.assertEqual(x.all(), x.cpu().all())
5486        for i in range(-5, 6):
5487            self.assertEqual(x.all(dim=i), x.cpu().all(dim=i))
5488
5489    def test_all(self):
5490        def helper(shape):
5491            input_xs = []
5492            prod = 1
5493
5494            for i in range(len(shape)):
5495                prod *= shape[i]
5496            input_xs.append(torch.randn(prod, dtype=torch.float).reshape(shape))
5497            input_xs.append(torch.arange(0, prod, dtype=torch.float).reshape(shape))
5498            input_xs.append(torch.ones(prod, dtype=torch.float).reshape(shape))
5499            input_xs.append(torch.zeros(prod, dtype=torch.float).reshape(shape))
5500            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape))
5501            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape))
5502            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape))
5503            input_xs.append(torch.arange(0, prod, dtype=torch.int).reshape(shape).bool())
5504            input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool())
5505            input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool())
5506
5507            for i, cpu_x in enumerate(input_xs):
5508                x = cpu_x.detach().clone().to('mps')
5509                y = torch.all(x)
5510                ref_y = torch.all(cpu_x)
5511                self.assertEqual(y, ref_y)
5512
5513                y_0 = torch.all(x, dim=0)
5514                refy_0 = torch.all(cpu_x, dim=0)
5515                self.assertEqual(y_0, refy_0)
5516
5517                y_0dim = torch.all(x, dim=0, keepdim=True)
5518                refy_0dim = torch.all(cpu_x, dim=0, keepdim=True)
5519                self.assertEqual(y_0dim, refy_0dim)
5520
5525                y_1 = torch.all(x, dim=1)
5526                refy_1 = torch.all(cpu_x, dim=1)
5527                self.assertEqual(y_1, refy_1)
5528
5529                y_1dim = torch.all(x, dim=1, keepdim=True)
5530                refy_1dim = torch.all(cpu_x, dim=1, keepdim=True)
5531                self.assertEqual(y_1dim, refy_1dim)
5532                if (len(shape) > 2):
5533                    y_2 = torch.all(x, dim=2)
5534                    refy_2 = torch.all(cpu_x, dim=2)
5535                    self.assertEqual(y_2, refy_2)
5536
5537                    y_2dim = torch.all(x, dim=2, keepdim=True)
5538                    refy_2dim = torch.all(cpu_x, dim=2, keepdim=True)
5539                    self.assertEqual(y_2dim, refy_2dim)
5540
5541                    y_3 = torch.all(x, dim=3)
5542                    refy_3 = torch.all(cpu_x, dim=3)
5543                    self.assertEqual(y_3, refy_3)
5544
5545                    y_3dim = torch.all(x, dim=3, keepdim=True)
5546                    refy_3dim = torch.all(cpu_x, dim=3, keepdim=True)
5547                    self.assertEqual(y_3dim, refy_3dim)
5548
5549        helper((1, 1, 1, 1))
5550        helper((1, 1, 3, 3))
5551        helper((7, 13))
5552        helper((2, 8, 4, 5))
5553        # Empty tensor
5554        x_cpu = torch.tensor([], dtype=torch.bool)
5555        x_mps = x_cpu.to("mps")
5556        self.assertEqual(x_cpu.all(), x_mps.all().cpu())
5557
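    # [Editorial sketch, not in the original suite] The empty-tensor check above
    # relies on the usual vacuous-truth convention, which also fixes any():
    # all() of no elements is True, any() of no elements is False.
    def _sketch_empty_reductions_are_vacuous(self):
        empty = torch.tensor([], dtype=torch.bool)
        self.assertTrue(empty.all().item())
        self.assertFalse(empty.any().item())
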
5558    # Test forward min
5559    def test_min_el(self):
5560        def helper(n, c, h, w):
5561            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
5562            x = cpu_x.detach().clone().to('mps')
5563
5564            y = torch.min(x)
5565            ref_y = torch.min(cpu_x)
5566            self.assertEqual(y, ref_y)
5567
5568            y_0, idx_0 = torch.min(x, dim=0)
5569            refy_0, refidx_0 = torch.min(cpu_x, dim=0)
5570            self.assertEqual(y_0, refy_0)
5571            self.assertEqual(idx_0, refidx_0)
5572
5573            y_0 = torch.ones(c, h, w, device='mps', dtype=torch.float)
5574            idx_0 = torch.ones(c, h, w, device='mps', dtype=torch.int64)
5575            torch.min(x, dim=0, out=(y_0, idx_0))
5576            refy_0, refidx_0 = torch.min(cpu_x, dim=0)
5577            self.assertEqual(y_0, refy_0)
5578            self.assertEqual(idx_0, refidx_0)
5579
5580            y_0dim, idx_0dim = torch.min(x, dim=0, keepdim=True)
5581            refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
5582            self.assertEqual(y_0dim, refy_0dim)
5583            self.assertEqual(idx_0dim, refidx_0dim)
5584
5585            y_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.float)
5586            idx_0dim = torch.ones(1, c, h, w, device='mps', dtype=torch.int64)
5587            torch.min(x, dim=0, keepdim=True, out=(y_0dim, idx_0dim))
5588            refy_0dim, refidx_0dim = torch.min(cpu_x, dim=0, keepdim=True)
5589            self.assertEqual(y_0dim, refy_0dim)
5590            self.assertEqual(idx_0dim, refidx_0dim)
5591
5592            y_1, idx_1 = torch.min(x, dim=1)
5593            refy_1, refidx_1 = torch.min(cpu_x, dim=1)
5594            self.assertEqual(y_1, refy_1)
5595            self.assertEqual(idx_1, refidx_1)
5596
5597            y_1 = torch.ones(n, h, w, device='mps', dtype=torch.float)
5598            idx_1 = torch.ones(n, h, w, device='mps', dtype=torch.int64)
5599            torch.min(x, dim=1, out=(y_1, idx_1))
5600            refy_1, refidx_1 = torch.min(cpu_x, dim=1)
5601            self.assertEqual(y_1, refy_1)
5602            self.assertEqual(idx_1, refidx_1)
5603
5604            y_1dim, idx_1dim = torch.min(x, dim=1, keepdim=True)
5605            refy_1dim, refidx_1dim = torch.min(cpu_x, dim=1, keepdim=True)
5606            self.assertEqual(y_1dim, refy_1dim)
5607            self.assertEqual(idx_1dim, refidx_1dim)
5608
5609            y_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.float)
5610            idx_1dim = torch.ones(n, 1, h, w, device='mps', dtype=torch.int64)
5611            torch.min(x, dim=1, keepdim=True, out=(y_1dim, idx_1dim))
5612            refy_1dim, refidx_1dim = torch.min(cpu_x, dim=1, keepdim=True)
5613            self.assertEqual(y_1dim, refy_1dim)
5614            self.assertEqual(idx_1dim, refidx_1dim)
5615
5616            y_2, idx_2 = torch.min(x, dim=2)
5617            refy_2, refidx_2 = torch.min(cpu_x, dim=2)
5618            self.assertEqual(y_2, refy_2)
5619            self.assertEqual(idx_2, refidx_2)
5620
5621            y_2 = torch.ones(n, c, w, device='mps', dtype=torch.float)
5622            idx_2 = torch.ones(n, c, w, device='mps', dtype=torch.int64)
5623            torch.min(x, dim=2, out=(y_2, idx_2))
5624            refy_2, refidx_2 = torch.min(cpu_x, dim=2)
5625            self.assertEqual(y_2, refy_2)
5626            self.assertEqual(idx_2, refidx_2)
5627
5628            y_2dim, idx_2dim = torch.min(x, dim=2, keepdim=True)
5629            refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True)
5630            self.assertEqual(y_2dim, refy_2dim)
5631            self.assertEqual(idx_2dim, refidx_2dim)
5632
5633            y_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.float)
5634            idx_2dim = torch.ones(n, c, 1, w, device='mps', dtype=torch.int64)
5635            torch.min(x, dim=2, keepdim=True, out=(y_2dim, idx_2dim))
5636            refy_2dim, refidx_2dim = torch.min(cpu_x, dim=2, keepdim=True)
5637            self.assertEqual(y_2dim, refy_2dim)
5638            self.assertEqual(idx_2dim, refidx_2dim)
5639
5640            y_3, idx_3 = torch.min(x, dim=3)
5641            refy_3, refidx_3 = torch.min(cpu_x, dim=3)
5642            self.assertEqual(y_3, refy_3)
5643            self.assertEqual(idx_3, refidx_3)
5644
5645            y_3 = torch.ones(n, c, h, device='mps', dtype=torch.float)
5646            idx_3 = torch.ones(n, c, h, device='mps', dtype=torch.int64)
5647            torch.min(x, dim=3, out=(y_3, idx_3))
5648            refy_3, refidx_3 = torch.min(cpu_x, dim=3)
5649            self.assertEqual(y_3, refy_3)
5650            self.assertEqual(idx_3, refidx_3)
5651
5652            y_3dim, idx_3dim = torch.min(x, dim=3, keepdim=True)
5653            refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True)
5654            self.assertEqual(y_3dim, refy_3dim)
5655            self.assertEqual(idx_3dim, refidx_3dim)
5656
5657            y_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.float)
5658            idx_3dim = torch.ones(n, c, h, 1, device='mps', dtype=torch.int64)
5659            torch.min(x, dim=3, keepdim=True, out=(y_3dim, idx_3dim))
5660            refy_3dim, refidx_3dim = torch.min(cpu_x, dim=3, keepdim=True)
5661            self.assertEqual(y_3dim, refy_3dim)
5662            self.assertEqual(idx_3dim, refidx_3dim)
5663
5664        helper(2, 8, 4, 5)
5665
5666    # Test forward sum
5667    def test_sum(self):
5668        def helper(n, c, h, w, dtype=torch.float32):
5669            cpu_x = None
5670            x = None
5671            if (dtype not in [torch.float32, torch.bool]):
5672                cpu_x = torch.randint(50, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5673                x = cpu_x.detach().clone().to('mps')
5674            elif (dtype == torch.bool):
5675                cpu_x = torch.randint(2, (n, c, h, w), device='cpu', dtype=dtype, requires_grad=False)
5676                x = cpu_x.detach().clone().to('mps')
5677            else:
5678                cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
5679                x = cpu_x.detach().clone().to('mps').requires_grad_()
5680
5681            all_sum = torch.sum(x)
5682            all_sum_cpu = torch.sum(cpu_x)
5683
5684            self.assertEqual(all_sum, all_sum_cpu)
5685
5686            nil_dim_sum = torch.sum(x, dim=[])
5687            nil_dim_sum_cpu = torch.sum(cpu_x, dim=[])
5688
5689            self.assertEqual(nil_dim_sum, nil_dim_sum_cpu)
5690
5691            nil_dim_sum_keepdim = torch.sum(x, dim=[], keepdim=True)
5692            nil_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[], keepdim=True)
5693
5694            self.assertEqual(nil_dim_sum_keepdim, nil_dim_sum_cpu_keepdim)
5695
5696            zero_dim_sum = torch.sum(x, dim=[0])
5697            zero_dim_sum_cpu = torch.sum(cpu_x, dim=[0])
5698
5699            self.assertEqual(zero_dim_sum, zero_dim_sum_cpu)
5700
5701            zero_dim_sum_keepdim = torch.sum(x, dim=[0], keepdim=True)
5702            zero_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0], keepdim=True)
5703
5704            self.assertEqual(zero_dim_sum_keepdim, zero_dim_sum_cpu_keepdim)
5705
5706            zero_one_dim_sum = torch.sum(x, dim=[0, 1])
5707            zero_one_dim_sum_cpu = torch.sum(cpu_x, dim=[0, 1])
5708
5709            self.assertEqual(zero_one_dim_sum, zero_one_dim_sum_cpu)
5710
5711            zero_one_dim_sum_keepdim = torch.sum(x, dim=[0, 1], keepdim=True)
5712            zero_one_dim_sum_cpu_keepdim = torch.sum(cpu_x, dim=[0, 1], keepdim=True)
5713
5714            self.assertEqual(zero_one_dim_sum_keepdim, zero_one_dim_sum_cpu_keepdim)
5715
5716            two_three_dim_sum = torch.sum(x, dim=[2, 3])
5717            two_three_dim_sum_cpu = torch.sum(cpu_x, dim=[2, 3])
5718
5719            self.assertEqual(two_three_dim_sum, two_three_dim_sum_cpu)
5720
5721            two_three_keepdim_sum = torch.sum(x, dim=[2, 3], keepdim=True)
5722            two_three_dim_keepsum_cpu = torch.sum(cpu_x, dim=[2, 3], keepdim=True)
5723
5724            self.assertEqual(two_three_keepdim_sum, two_three_dim_keepsum_cpu)
5725
5726        helper(2, 8, 4, 5)
5727        helper(2, 8, 4, 5, dtype=torch.int32)
5728        helper(2, 8, 4, 5, dtype=torch.int64)
5729        helper(2, 8, 4, 5, dtype=torch.bool)
5730        # Regression test for https://github.com/pytorch/pytorch/issues/136132
5731        x = torch.ones(2, 4, 1, 30, 1, device='mps').sum(dim=-2)
5732        self.assertEqual(x.numel(), 8)
5733        self.assertEqual(x.max().item(), 30.0)
5734
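    # [Editorial sketch, not in the original suite] The multi-dim cases above
    # (e.g. dim=[2, 3]) are equivalent to reducing one dimension at a time, a handy
    # cross-check when debugging a reduction kernel (hypothetical `_sketch_` name):
    def _sketch_multi_dim_sum_decomposes(self):
        x = torch.randn(2, 8, 4, 5)
        self.assertEqual(x.sum(dim=[2, 3]), x.sum(dim=3).sum(dim=2))
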
5735    # Test forward prod
5736    def test_prod(self):
5737        def helper(shape, dtype=torch.float32):
5738            cpu_x = None
5739            x = None
5740            if (dtype not in [torch.float32, torch.bool]):
5741                cpu_x = torch.randint(1, 6, shape, device='cpu', dtype=dtype, requires_grad=False)
5742                x = cpu_x.detach().clone().to('mps')
5743            elif (dtype == torch.bool):
5744                cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
5745                x = cpu_x.detach().clone().to('mps')
5746            else:
5747                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
5748                x = cpu_x.detach().clone().to('mps').requires_grad_()
5749
5750            all_prod = torch.prod(x)
5751            all_prod_cpu = torch.prod(cpu_x)
5752
5753            self.assertEqual(all_prod, all_prod_cpu)
5754
5755            for dim in range(len(shape)):
5756                dim_prod = torch.prod(x, dim=dim)
5757                dim_prod_cpu = torch.prod(cpu_x, dim=dim)
5758
5759                self.assertEqual(dim_prod, dim_prod_cpu)
5760
5761                dim_prod_keepdim = torch.prod(x, dim=dim, keepdim=True)
5762                dim_prod_cpu_keepdim = torch.prod(cpu_x, dim=dim, keepdim=True)
5763
5764                self.assertEqual(dim_prod_keepdim, dim_prod_cpu_keepdim)
5765
5766        for dtype in [torch.float32, torch.int32, torch.int64, torch.bool]:
5767            helper((2, 3), dtype)
5768
5769    # Test forward mean
5770    def test_mean(self):
5771        def helper(n, c, h, w):
5772            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=True)
5773            x = cpu_x.detach().clone().to('mps').requires_grad_()
5774
5775            all_mean = torch.mean(x)
5776            all_mean_cpu = torch.mean(cpu_x)
5777
5778            self.assertEqual(all_mean, all_mean_cpu)
5779
5780            nil_dim_mean = torch.mean(x, dim=[])
5781            nil_dim_mean_cpu = torch.mean(cpu_x, dim=[])
5782
5783            self.assertEqual(nil_dim_mean, nil_dim_mean_cpu)
5784
5785            nil_dim_mean_keepdim = torch.mean(x, dim=[], keepdim=True)
5786            nil_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[], keepdim=True)
5787
5788            self.assertEqual(nil_dim_mean_keepdim, nil_dim_mean_cpu_keepdim)
5789
5790            zero_dim_mean = torch.mean(x, dim=[0])
5791            zero_dim_mean_cpu = torch.mean(cpu_x, dim=[0])
5792
5793            self.assertEqual(zero_dim_mean, zero_dim_mean_cpu)
5794
5795            zero_dim_mean_keepdim = torch.mean(x, dim=[0], keepdim=True)
5796            zero_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0], keepdim=True)
5797
5798            self.assertEqual(zero_dim_mean_keepdim, zero_dim_mean_cpu_keepdim)
5799
5800            zero_one_dim_mean = torch.mean(x, dim=[0, 1])
5801            zero_one_dim_mean_cpu = torch.mean(cpu_x, dim=[0, 1])
5802
5803            self.assertEqual(zero_one_dim_mean, zero_one_dim_mean_cpu)
5804
5805            zero_one_dim_mean_keepdim = torch.mean(x, dim=[0, 1], keepdim=True)
5806            zero_one_dim_mean_cpu_keepdim = torch.mean(cpu_x, dim=[0, 1], keepdim=True)
5807
5808            self.assertEqual(zero_one_dim_mean_keepdim, zero_one_dim_mean_cpu_keepdim)
5809
5810            two_three_dim_mean = torch.mean(x, dim=[2, 3])
5811            two_three_dim_mean_cpu = torch.mean(cpu_x, dim=[2, 3])
5812
5813            self.assertEqual(two_three_dim_mean, two_three_dim_mean_cpu)
5814
5815            two_three_keepdim_mean = torch.mean(x, dim=[2, 3], keepdim=True)
5816            two_three_dim_keepmean_cpu = torch.mean(cpu_x, dim=[2, 3], keepdim=True)
5817
5818            self.assertEqual(two_three_keepdim_mean, two_three_dim_keepmean_cpu)
5819
5820        helper(2, 8, 4, 5)
5821
5822    # Test std
5823    def test_std(self):
5824        def helper(shape):
5825            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
5826            x = cpu_x.detach().clone().to('mps')
5827
5828            all_std = torch.std(x, unbiased=False)
5829            all_std_cpu = torch.std(cpu_x, unbiased=False)
5830
5831            self.assertEqual(all_std, all_std_cpu)
5832
5833            nil_dim_std = torch.std(x, dim=[], unbiased=False)
5834            nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=False)
5835
5836            self.assertEqual(nil_dim_std, nil_dim_std_cpu)
5837
5838            nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=False)
5839            nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=False)
5840
5841            self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)
5842
5843            zero_dim_std = torch.std(x, dim=[0], unbiased=False)
5844            zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=False)
5845
5846            self.assertEqual(zero_dim_std, zero_dim_std_cpu)
5847
5848            zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=False)
5849            zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=False)
5850
5851            self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)
5852
5853            zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=False)
5854            zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=False)
5855
5856            self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)
5857
5858            zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=False)
5859            zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=False)
5860
5861            self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)
5862
5863            two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=False)
5864            two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=False)
5865
5866            self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)
5867
5868            two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=False)
5869            two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=False)
5870
5871            self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)
5872
5873            all_std = torch.std(x, unbiased=True)
5874            all_std_cpu = torch.std(cpu_x, unbiased=True)
5875
5876            self.assertEqual(all_std, all_std_cpu)
5877
5878            nil_dim_std = torch.std(x, dim=[], unbiased=True)
5879            nil_dim_std_cpu = torch.std(cpu_x, dim=[], unbiased=True)
5880
5881            self.assertEqual(nil_dim_std, nil_dim_std_cpu)
5882
5883            nil_dim_std_keepdim = torch.std(x, dim=[], keepdim=True, unbiased=True)
5884            nil_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[], keepdim=True, unbiased=True)
5885
5886            self.assertEqual(nil_dim_std_keepdim, nil_dim_std_cpu_keepdim)
5887
5888            zero_dim_std = torch.std(x, dim=[0], unbiased=True)
5889            zero_dim_std_cpu = torch.std(cpu_x, dim=[0], unbiased=True)
5890
5891            self.assertEqual(zero_dim_std, zero_dim_std_cpu)
5892
5893            zero_dim_std_keepdim = torch.std(x, dim=[0], keepdim=True, unbiased=True)
5894            zero_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0], keepdim=True, unbiased=True)
5895
5896            self.assertEqual(zero_dim_std_keepdim, zero_dim_std_cpu_keepdim)
5897
5898            zero_one_dim_std = torch.std(x, dim=[0, 1], unbiased=True)
5899            zero_one_dim_std_cpu = torch.std(cpu_x, dim=[0, 1], unbiased=True)
5900
5901            self.assertEqual(zero_one_dim_std, zero_one_dim_std_cpu)
5902
5903            zero_one_dim_std_keepdim = torch.std(x, dim=[0, 1], keepdim=True, unbiased=True)
5904            zero_one_dim_std_cpu_keepdim = torch.std(cpu_x, dim=[0, 1], keepdim=True, unbiased=True)
5905
5906            self.assertEqual(zero_one_dim_std_keepdim, zero_one_dim_std_cpu_keepdim)
5907
5908            two_three_dim_std = torch.std(x, dim=[2, 3], unbiased=True)
5909            two_three_dim_std_cpu = torch.std(cpu_x, dim=[2, 3], unbiased=True)
5910
5911            self.assertEqual(two_three_dim_std, two_three_dim_std_cpu)
5912
5913            two_three_keepdim_std = torch.std(x, dim=[2, 3], keepdim=True, unbiased=True)
5914            two_three_dim_keepstd_cpu = torch.std(cpu_x, dim=[2, 3], keepdim=True, unbiased=True)
5915
5916            self.assertEqual(two_three_keepdim_std, two_three_dim_keepstd_cpu)
5917
5918        helper((4, 5, 6, 7))
5919        # verify that a change in input shape does not cause problems with graph caching
5920        helper((9, 5, 6, 7))
5921
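    # [Editorial sketch, not in the original suite] What the unbiased flag toggles:
    # std = sqrt(sum((x - mean)^2) / (N - 1)) with Bessel's correction when
    # unbiased=True, and / N when unbiased=False (hypothetical `_sketch_` name):
    def _sketch_std_divisors(self):
        x = torch.randn(100)
        sq = (x - x.mean()).pow(2).sum()
        self.assertEqual(torch.std(x, unbiased=True), torch.sqrt(sq / (x.numel() - 1)))
        self.assertEqual(torch.std(x, unbiased=False), torch.sqrt(sq / x.numel()))
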
5922    # Test var
5923    def test_var_simple(self):
5924        def helper():
5925
5926            shape = [2, 3, 4, 5]
5927
5928            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
5929            x = cpu_x.detach().clone().to('mps')
5930
5931            for unbiased in [False, True]:
5932                for keepdim in [False, True]:
5933
5934                    zero_dim_var = x.var(-1, keepdim=keepdim, unbiased=unbiased)
5935                    zero_dim_var_cpu = cpu_x.var(-1, keepdim=keepdim, unbiased=unbiased)
5936
5937                    self.assertEqual(zero_dim_var, zero_dim_var_cpu)
5938
5939                    all_var = torch.var(x, unbiased=unbiased)
5940                    all_var_cpu = torch.var(cpu_x, unbiased=unbiased)
5941
5942                    self.assertEqual(all_var, all_var_cpu)
5943
5944                    nil_dim_var = torch.var(x, dim=[], keepdim=keepdim, unbiased=unbiased)
5945                    nil_dim_var_cpu = torch.var(cpu_x, dim=[], keepdim=keepdim, unbiased=unbiased)
5946
5947                    self.assertEqual(nil_dim_var, nil_dim_var_cpu)
5948
5949                    zero_dim_var = torch.var(x, dim=[0], keepdim=keepdim, unbiased=unbiased)
5950                    zero_dim_var_cpu = torch.var(cpu_x, dim=[0], keepdim=keepdim, unbiased=unbiased)
5951
5952                    self.assertEqual(zero_dim_var, zero_dim_var_cpu)
5953
5954                    zero_one_dim_var = torch.var(x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
5955                    zero_one_dim_var_cpu = torch.var(cpu_x, dim=[0, -1], keepdim=keepdim, unbiased=unbiased)
5956
5957                    self.assertEqual(zero_one_dim_var, zero_one_dim_var_cpu)
5958
5959                    two_three_dim_var = torch.var(x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
5960                    two_three_dim_var_cpu = torch.var(cpu_x, dim=[2, 3], keepdim=keepdim, unbiased=unbiased)
5961
5962                    self.assertEqual(two_three_dim_var, two_three_dim_var_cpu)
5963
5964        helper()
5965
5966    # Test forward amax
5967    def test_amax(self):
5968        def helper(shape, dim, keepdim):
5969            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
5970            x = cpu_x.detach().clone().to('mps').requires_grad_()
5971
5972            result = torch.amax(x, dim=dim, keepdim=keepdim)
5973            result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)
5974
5975            cpu_grad = torch.randn(result_cpu.shape)
5976            grad = cpu_grad.to('mps')
5977
5978            result_cpu.backward(gradient=cpu_grad)
5979            result.backward(gradient=grad)
5980
5981            self.assertEqual(result, result_cpu)
5982            self.assertEqual(x.grad, cpu_x.grad)
5983
5984        for dim in ([], [0], [0, 1], [2, 3]):
5985            for keepdim in [False, True]:
5986                helper((2, 8, 4, 5), dim, keepdim)
5987
5988    # Test forward amin
5989    def test_amin(self):
5990        def helper(shape, dim, keepdim):
5991            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
5992            x = cpu_x.detach().clone().to('mps').requires_grad_()
5993
5994            result = torch.amin(x, dim=dim, keepdim=keepdim)
5995            result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)
5996
5997            cpu_grad = torch.randn(result_cpu.shape)
5998            grad = cpu_grad.to('mps')
5999
6000            result_cpu.backward(gradient=cpu_grad)
6001            result.backward(gradient=grad)
6002
6003            self.assertEqual(result, result_cpu)
6004            self.assertEqual(x.grad, cpu_x.grad)
6005
6006        for dim in ([], [0], [0, 1], [2, 3]):
6007            for keepdim in [False, True]:
6008                helper((2, 8, 4, 5), dim, keepdim)
6009
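    # [Editorial sketch, not in the original suite] Unlike max(dim), which routes the
    # gradient to a single winning index, amax/amin split the gradient evenly across
    # tied extrema, which is why the backward comparisons above need no index
    # bookkeeping. CPU illustration (hypothetical `_sketch_` name):
    def _sketch_amax_splits_gradient_on_ties(self):
        x = torch.tensor([3., 3., 1.], requires_grad=True)
        torch.amax(x).backward()
        self.assertEqual(x.grad, torch.tensor([0.5, 0.5, 0.]))
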
6010    # Test minimum and maximum
6011    def test_minimum_maximum(self):
6012        def helper(n, c, h, w):
6013            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6014            cpu_y = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6015            mps_x = cpu_x.detach().clone().to('mps')
6016            mps_y = cpu_y.detach().clone().to('mps')
6017
6018            minimum_result_cpu = torch.minimum(cpu_x, cpu_y)
6019            minimum_result_mps = torch.minimum(mps_x, mps_y)
6020            self.assertEqual(minimum_result_cpu, minimum_result_mps)
6021
6022            maximum_result_cpu = torch.maximum(cpu_x, cpu_y)
6023            maximum_result_mps = torch.maximum(mps_x, mps_y)
6024            self.assertEqual(maximum_result_cpu, maximum_result_mps)
6025
6026        helper(1, 1, 4, 5)
6027
6028    def test_clamp_fp16_fp32(self):
6029        cpu_x = torch.randn(10, device='cpu', dtype=torch.float, requires_grad=False)
6030        x = cpu_x.detach().clone().to('mps')
6031
6034        clamp_min_vals_mps = torch.ones(10, device="mps").to(torch.float16)
6035        clamp_max_vals_mps = torch.ones(10, device="mps").to(torch.float16) * 10
6036        clamp_result_mps = torch.clamp(x, clamp_min_vals_mps, clamp_max_vals_mps)
6037
6038        clamp_min_vals_cpu = torch.ones(10, device="cpu").to(torch.float16)
6039        clamp_max_vals_cpu = torch.ones(10, device="cpu").to(torch.float16) * 10
6040        clamp_result_cpu = torch.clamp(cpu_x, clamp_min_vals_cpu, clamp_max_vals_cpu)
6041
6042        self.assertEqual(clamp_result_mps, clamp_result_cpu)
6043
6044    def test_clamp_nan(self):
6045        t_mps = torch.tensor([torch.nan, 1, 2], device="mps")
6046        t_cpu = torch.tensor([torch.nan, 1, 2], device="cpu")
6047
6048        clamp_min_max_mps = torch.clamp(t_mps, min=-100, max=100)
6049        clamp_min_max_cpu = torch.clamp(t_cpu, min=-100, max=100)
6050
6051        self.assertEqual(clamp_min_max_mps, clamp_min_max_cpu)
6052
6053        clamp_min_mps = torch.clamp(t_mps, min=-100)
6054        clamp_min_cpu = torch.clamp(t_cpu, min=-100)
6055
6056        self.assertEqual(clamp_min_mps, clamp_min_cpu)
6057
6058        clamp_max_mps = torch.clamp(t_mps, max=100)
6059        clamp_max_cpu = torch.clamp(t_cpu, max=100)
6060
6061        self.assertEqual(clamp_max_mps, clamp_max_cpu)
6062
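    # [Editorial sketch, not in the original suite] The NaN cases above pin down the
    # propagation rule: clamp leaves NaN inputs as NaN whatever the bounds, since NaN
    # fails every comparison and neither bound replaces it (hypothetical name):
    def _sketch_clamp_propagates_nan(self):
        out = torch.tensor([torch.nan, 1., 200.]).clamp(min=-100, max=100)
        self.assertTrue(torch.isnan(out[0]).item())
        self.assertEqual(out[2].item(), 100.)
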
6063    # Test clamp_min
6064    def test_clamp_min(self):
6065        def helper(n, c, h, w):
6066            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6067            x = cpu_x.detach().clone().to('mps')
6068
6069            cpu_min_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6070            min_t = cpu_min_t.detach().clone().to('mps')
6071
6072            clamp_min_result = torch.clamp_min(x, min=5.0)
6073            clamp_min_result_cpu = torch.clamp_min(cpu_x, min=5.0)
6074
6075            self.assertEqual(clamp_min_result, clamp_min_result_cpu)
6076
6077            clamp_min_t_result = torch.clamp_min(x, min=min_t)
6078            clamp_min_t_result_cpu = torch.clamp_min(cpu_x, min=cpu_min_t)
6079
6080            self.assertEqual(clamp_min_t_result, clamp_min_t_result_cpu)
6081
6082        helper(2, 8, 4, 5)
6083
6084    # Test clamp_max
6086    def test_clamp_max(self):
6087        def helper(n, c, h, w):
6088            cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6089            x = cpu_x.detach().clone().to('mps')
6090
6091            cpu_max_t = torch.randn(n, c, h, w, device='cpu', dtype=torch.float, requires_grad=False)
6092            max_t = cpu_max_t.detach().clone().to('mps')
6093
6094            clamp_max_result = torch.clamp_max(x, max=100.0)
6095            clamp_max_result_cpu = torch.clamp_max(cpu_x, max=100.0)
6096
6097            self.assertEqual(clamp_max_result, clamp_max_result_cpu)
6098
6099            clamp_max_t_result = torch.clamp_max(x, max=max_t)
6100            clamp_max_t_result_cpu = torch.clamp_max(cpu_x, max=cpu_max_t)
6101
6102            self.assertEqual(clamp_max_t_result, clamp_max_t_result_cpu)
6103
6104        helper(2, 8, 4, 5)
6105
6106    # Test clamp
6107    def test_clamp(self):
6108        def helper(n, c, h, w):
6110            upper_bound = 1000
6111            half_upper_bound = upper_bound / 2
6112
6113            # x=[0..1000)
6114            x_arr = upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
6115            cpu_x = torch.tensor(x_arr, device='cpu', dtype=torch.float, requires_grad=False)
6116            x = cpu_x.detach().clone().to('mps')
6117
6118            # min=[0..500)
6119            min_arr = half_upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)
6120            cpu_min_t = torch.tensor(min_arr, device='cpu', dtype=torch.float, requires_grad=False)
6121            min_t = cpu_min_t.detach().clone().to('mps')
6122
6123            # max=[500..1000), to ensure max values are greater than mins
6124            max_arr = (half_upper_bound * np.random.random_sample(size=(n, c, h, w)).astype(np.float32)) + half_upper_bound
6125            cpu_max_t = torch.tensor(max_arr, device='cpu', dtype=torch.float, requires_grad=False)
6126            max_t = cpu_max_t.detach().clone().to('mps')
6127
6128            # [200..600]: just an arbitrary range between [0..1000]
6129            clamp_result = torch.clamp(x, min=200.0, max=600.0)
6130            clamp_result_cpu = torch.clamp(cpu_x, min=200.0, max=600.0)
6131            self.assertEqual(clamp_result, clamp_result_cpu)
6132
6133            # test optional scalar refs and cached graph keys by passing only max
6134            clamp_opt_result = torch.clamp(x, max=600.0)
6135            clamp_opt_result_cpu = torch.clamp(cpu_x, max=600.0)
6136            self.assertEqual(clamp_opt_result, clamp_opt_result_cpu)
6137
6138            clamp_t_result = torch.clamp(x, min=min_t, max=max_t)
6139            clamp_t_result_cpu = torch.clamp(cpu_x, min=cpu_min_t, max=cpu_max_t)
6140            self.assertEqual(clamp_t_result, clamp_t_result_cpu)
6141
6142            # test optional tensor refs and cached graph keys by passing only max
6143            clamp_topt_result = torch.clamp(x, max=max_t)
6144            clamp_topt_result_cpu = torch.clamp(cpu_x, max=cpu_max_t)
6145            self.assertEqual(clamp_topt_result, clamp_topt_result_cpu)
6146
6147            # test strided x
6148            clamp_result = torch.clamp(x.movedim(0, -1), min=200.0, max=600.0)
6149            clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=200.0, max=600.0)
6150            self.assertEqual(clamp_result, clamp_result_cpu)
6151
6152            # test strided x, min_t, max_t
6153            clamp_result = torch.clamp(x.movedim(0, -1), min=min_t.movedim(0, -1), max=max_t.movedim(0, -1))
6154            clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=cpu_min_t.movedim(0, -1), max=cpu_max_t.movedim(0, -1))
6155            self.assertEqual(clamp_result, clamp_result_cpu)
6156
6157            # test strided min_t, max_t
6158            clamp_result = torch.clamp(
6159                x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
6160                min=min_t.movedim(0, -1),
6161                max=max_t.movedim(0, -1)
6162            )
6163            clamp_result_cpu = torch.clamp(
6164                cpu_x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
6165                min=cpu_min_t.movedim(0, -1),
6166                max=cpu_max_t.movedim(0, -1)
6167            )
6168            self.assertEqual(clamp_result, clamp_result_cpu)
6169
6170            # test inplace clamping
6171            x.clamp_(min=200.0, max=600.0)
6172            cpu_x.clamp_(min=200.0, max=600.0)
6173            self.assertEqual(cpu_x, x)
6174
6175        helper(2, 8, 4, 5)
6176
6177    def test_divmode(self):
6178        def helper(shape, rounding_mode):
6179            for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
6180                if not ((rounding_mode is not None and "floor" in rounding_mode and dtype == torch.int64) or
6181                        (rounding_mode is not None and "trunc" in rounding_mode and dtype == torch.float16)):
6182                    cpu_x = None
6183                    cpu_y = None
6184                    if (dtype in [torch.float32, torch.float16]):
6185                        cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
6186                        cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
6187                    else:
6188                        cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
6189                        cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
6190
6191                    mps_x = cpu_x.detach().clone().to('mps')
6192                    # the integer case draws y from randint(-10, 0), which excludes zero, so division by zero cannot occur
6193                    mps_y = cpu_y.detach().clone().to('mps')
6194
6195                    if (rounding_mode == "floor_divide"):
6196                        result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
6197                        result_div_mps = torch.floor_divide(mps_x, mps_y)
6198                        self.assertEqual(result_div_mps, result_div_cpu)
6199                    else:
6200                        result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
6201                        result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
6202                        self.assertEqual(result_div_mps, result_div_cpu)
6203
6204        helper((2, 8, 4, 5), None)
6205        helper((2, 8, 4, 5), "floor")
6206        helper((2, 8, 4, 5), "trunc")
6207        helper((2, 8, 4, 5), "floor_divide")
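        # For reference, the modes only differ for quotients with a fractional
        # part, e.g. -7 / 2 = -3.5: rounding_mode='floor' gives -4 (toward -inf),
        # 'trunc' gives -3 (toward zero), and None keeps true division (-3.5).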
6208
6209    def test_rounding(self):
6210        def helper(shape):
6211            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6212            mps_x = cpu_x.detach().clone().to('mps')
6213
6214            result_floor_cpu = torch.floor(cpu_x)
6215            result_floor_mps = torch.floor(mps_x)
6216            self.assertEqual(result_floor_mps, result_floor_cpu)
6217
6218            result_ceil_cpu = torch.ceil(cpu_x)
6219            result_ceil_mps = torch.ceil(mps_x)
6220            self.assertEqual(result_ceil_mps, result_ceil_cpu)
6221
6222            result_trunc_cpu = torch.trunc(cpu_x)
6223            result_trunc_mps = torch.trunc(mps_x)
6224            self.assertEqual(result_trunc_mps, result_trunc_cpu)
6225
6226            result_round_cpu = torch.round(cpu_x)
6227            result_round_mps = torch.round(mps_x)
6228            self.assertEqual(result_round_mps, result_round_cpu)
6229
6230        helper((2, 6, 3, 5))
6231        helper((2, 8, 4, 5))
6232
6233    def test_remainder(self):
6234        res_cpu = torch.remainder(
6235            torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device="cpu"), torch.tensor(2, device="cpu", dtype=torch.int32))
6236        res_mps = torch.remainder(
6237            torch.tensor([-3, -2, -1, 1, 2, 3], dtype=torch.int32, device="mps"), torch.tensor(2, device="mps", dtype=torch.int32))
6238        self.assertEqual(res_cpu, res_mps)
6239
6240        res_cpu = torch.remainder(
6241            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="cpu"), -1.5)
6242        res_mps = torch.remainder(
6243            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32, device="mps"), -1.5)
6244        self.assertEqual(res_cpu, res_mps)
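        # torch.remainder follows Python's modulo convention: the result takes
        # the sign of the divisor, e.g. remainder(-3, 2) == 1 and
        # remainder(1, -1.5) == -0.5, which the cases above exercise on both devices.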
6245
6246    def test_expand(self):
6247        def helper(n, c):
6248            values = [[1.0], [4.0], [7.0]]
6249            cpu_x = torch.tensor(values, device='cpu')
6250            x = cpu_x.detach().clone().to('mps')
6251
6252            strided_cpu = torch.as_strided(cpu_x, (3, 4), (1, 0))
6253            strided_mps = torch.as_strided(x, (3, 4), (1, 0))
6254
6255            self.assertEqual(strided_mps, strided_cpu)
6256
6257        helper(3, 1)
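        # A stride of 0 along dim 1 makes every column alias the same storage
        # element, so as_strided(x, (3, 4), (1, 0)) is equivalent to
        # x.expand(3, 4) for the (3, 1) input above, without copying data.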
6258
6259    def test_im2col(self):
6260        def helper(x):
6261            return torch.nn.functional.unfold(x, kernel_size=(10, 15), dilation=2, padding=5, stride=3)
6262        x_cpu = torch.rand(1, 1, 200, 100)
6263        x = x_cpu.detach().clone().to('mps')
6264        self.assertEqual(helper(x_cpu), helper(x))
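        # Each unfold output dimension follows the usual convolution arithmetic,
        # out = floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1,
        # so the (1, 1, 200, 100) input above should yield shape (1, 10*15, 64*28).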
6265
6266    def test_select(self):
6267        def helper(n, c):
6268            cpu_x = torch.randn(n, c, device='cpu', dtype=torch.float, requires_grad=True)
6269            x = cpu_x.detach().clone().to('mps').requires_grad_()
6270
6271            strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1))
6272            strided_mps = torch.as_strided(x, (3, 1), (3, 1))
6273            self.assertEqual(strided_mps, strided_cpu)
6274
6275            strided_cpu = torch.as_strided(cpu_x, (1, 3), (3, 1))
6276            strided_mps = torch.as_strided(x, (1, 3), (3, 1))
6277            self.assertEqual(strided_mps, strided_cpu)
6278
6279            strided_cpu = torch.as_strided(cpu_x, (3, 1), (3, 1), storage_offset=1)
6280            strided_mps = torch.as_strided(x, (3, 1), (3, 1), storage_offset=1)
6281
6282            self.assertEqual(strided_mps, strided_cpu)
6283
6284        helper(3, 3)
6285
6286    def test_sort(self):
6287        for SIZE in (4, 2049):
6288            device = 'mps'
6289            x = torch.rand(4, SIZE, device=device)
6290            res1val, res1ind = torch.sort(x)
6291
6292            res2val = torch.tensor((), device=device)
6293            res2ind = torch.tensor((), device=device, dtype=torch.long)
6294            torch.sort(x, out=(res2val, res2ind))
6295            self.assertEqual(res1val, res2val, atol=0, rtol=0)
6296            self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
6297            self.assertEqual(torch.argsort(x), res1ind)
6298            self.assertEqual(x.argsort(), res1ind)
6299
6300            self.assertEqual(
6301                torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0],
6302                torch.tensor((10, 20, 30, 40, 50), device=device),
6303                atol=0, rtol=0
6304            )
6305
6306    def test_upsample_nearest2d(self):
6307        def helper(N, C, H, W, memory_format):
6308            inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
6309                                    requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
6310            inputCPU.retain_grad()
6311            inputMPS = inputCPU.detach().to('mps').requires_grad_()
6312
6313            values = [1, 2, 5, 10, 40]
6314
6315            for i in values:
6316                for j in values:
6317                    upsample_nearest2d = nn.UpsamplingNearest2d(scale_factor=(i, j))
6318
6319                    outputCPU = upsample_nearest2d(inputCPU)
6320                    outputMPS = upsample_nearest2d(inputMPS)
6321
6322                    self.assertEqual(outputCPU, outputMPS)
6323                    upsample_nearest2d = nn.UpsamplingNearest2d((i * H, j * W))
6324
6325                    outputCPU = upsample_nearest2d(inputCPU)
6326                    outputMPS = upsample_nearest2d(inputMPS)
6327
6328                    self.assertEqual(outputCPU, outputMPS)
6329
6330                    outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
6331                    outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
6332
6333                    self.assertEqual(inputCPU.grad, inputMPS.grad)
6334
6335        for memory_format in [torch.channels_last, torch.contiguous_format]:
6336            helper(1, 1, 4, 4, memory_format=memory_format)
6337            helper(7, 5, 3, 2, memory_format=memory_format)
6338
6339    def test_upsample_bilinear2d(self):
6340        def helper(N, C, H, W):
6341            inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
6342                                    requires_grad=True).reshape(N, C, H, W)
6343            inputCPU.retain_grad()
6344            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
6345
6346            values = [1, 2, 5, 10, 40]
6347
6348            for i in values:
6349                for j in values:
6350                    upsample_bilinear2d = nn.UpsamplingBilinear2d(scale_factor=(i, j))
6351
6352                    outputCPU = upsample_bilinear2d(inputCPU)
6353                    outputMPS = upsample_bilinear2d(inputMPS)
6354
6355                    self.assertEqual(outputCPU, outputMPS)
6356
6357                    upsample_bilinear2d = nn.UpsamplingBilinear2d((i * H, j * W))
6358
6359                    outputCPU = upsample_bilinear2d(inputCPU)
6360                    outputMPS = upsample_bilinear2d(inputMPS)
6361
6362                    self.assertEqual(outputCPU, outputMPS)
6363
6364                    outputCPU.backward(gradient=torch.full_like(outputCPU, 0.3))
6365                    outputMPS.backward(gradient=torch.full_like(outputMPS, 0.3))
6366
6367                    self.assertEqual(inputCPU.grad, inputMPS.grad)
6368
6369        helper(1, 1, 4, 4)
6370        helper(7, 5, 3, 2)
6371
6372    def test_interpolate(self):
6373        def helper(shape, output_size, scales, mode, align_corners=False):
6374            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
6375            inputCPU.retain_grad()
6376            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
6377
6378            # align_corners is used for 2D interpolation only
6379            if (align_corners is True and len(shape) > 3 and mode == 'bilinear'):
6380                if scales is not None:
6381                    outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode, align_corners=align_corners)
6382                    outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode, align_corners=align_corners)
6383                else:
6384                    outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode, align_corners=align_corners)
6385                    outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode, align_corners=align_corners)
6386            elif scales is not None:
6387                outputCPU = nn.functional.interpolate(inputCPU, scale_factor=scales, mode=mode)
6388                outputMPS = nn.functional.interpolate(inputMPS, scale_factor=scales, mode=mode)
6389            else:
6390                outputCPU = nn.functional.interpolate(inputCPU, size=output_size, mode=mode)
6391                outputMPS = nn.functional.interpolate(inputMPS, size=output_size, mode=mode)
6392
6393            self.assertEqual(outputCPU, outputMPS)
6394
6395            # backward pass (chose 0.6 just to have the grad_output != 1)
6396            outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
6397            outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
6398            self.assertEqual(inputCPU.grad, inputMPS.grad)
6399
6400        # 1D interpolation
6401        for mode in ['nearest', 'nearest-exact']:
6402            helper([2, 3, 4], [3], None, mode)  # downsample with size
6403            helper([2, 3, 4], [6], None, mode)  # upsample with size
6404            helper([2, 3, 4], None, [0.6], mode)  # downsample with scale factor
6405            helper([2, 3, 4], None, [1.7], mode)  # upsample with scale factor
6406        # 2D interpolation
6407        for mode in ['nearest', 'nearest-exact', 'bilinear']:
6408            helper([2, 3, 4, 5], [3, 4], None, mode)  # downsample with size
6409            helper([2, 3, 4, 5], [6, 7], None, mode)  # upsample with size
6410            helper([2, 3, 4, 5], None, [0.6, 0.7], mode)  # downsample with scale factor
6411            helper([2, 3, 4, 5], None, [1.4, 1.7], mode)  # upsample with scale factor
6412        # align_corners=True
6413        helper([2, 3, 4, 5], [3, 4], None, 'bilinear', True)
6414        helper([2, 3, 4, 5], None, [1.4, 1.7], 'bilinear', True)
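        # For the bilinear cases, align_corners=True maps destination to source
        # coordinates as x_src = x_dst * (in - 1) / (out - 1) so that corner
        # pixels coincide, while the default mapping is
        # x_src = (x_dst + 0.5) * in / out - 0.5.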
6415
6416    # Test concat forward
6417    def test_cat1(self):
6418        def helper(shape_x, shape_y, shape_z):
6419            cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
6420            x = cpu_x.detach().clone().to('mps')
6421
6422            cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
6423            y = cpu_y.detach().clone().to('mps')
6424
6425            cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
6426            z = cpu_z.detach().clone().to('mps')
6427
6428            cat = torch.cat([x, y, z], dim=1)
6429            cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1)
6430
6431            self.assertEqual(cat, cat_cpu)
6432
6433        helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
6434        helper([2, 2, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5])
6435        helper([0, 2, 4, 5], [0, 3, 4, 5], [0, 5, 4, 5])
6436        helper([2, 2, 6, 5], [0], [2, 5, 6, 5])
6437        helper([0], [2, 3, 6, 5], [2, 5, 6, 5])
6438        helper([2, 3, 4, 5], [2, 5, 4, 5], [0])
6439        helper([2, 2, 6, 5], [2, 0, 6, 5], [2, 5, 6, 5])
6440        helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 5, 6, 5])
6441        helper([2, 0, 6, 5], [2, 3, 6, 5], [2, 0, 6, 5])
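        # The 1-D [0] operands rely on cat's legacy behavior of skipping empty
        # tensors of shape (0,) regardless of the other shapes; the [2, 0, 6, 5]
        # cases instead exercise genuinely empty slices along the cat dimension.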
6442
6443    # Test stack forward
6444    def test_stack(self):
6445        # All shapes must be the same
6446        def helper(shape, dtype=torch.float32):
6447
6448            x, cpu_x = None, None
6449            y, cpu_y = None, None
6450            z, cpu_z = None, None
6451
6452            if (dtype not in [torch.float32, torch.bool]):
6453                cpu_x = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
6454                x = cpu_x.detach().clone().to('mps')
6455                cpu_y = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
6456                y = cpu_y.detach().clone().to('mps')
6457                cpu_z = torch.randint(50, shape, device='cpu', dtype=dtype, requires_grad=False)
6458                z = cpu_z.detach().clone().to('mps')
6459            elif (dtype == torch.bool):
6460                cpu_x = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
6461                x = cpu_x.detach().clone().to('mps')
6462                cpu_y = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
6463                y = cpu_y.detach().clone().to('mps')
6464                cpu_z = torch.randint(2, shape, device='cpu', dtype=dtype, requires_grad=False)
6465                z = cpu_z.detach().clone().to('mps')
6466            else:
6467                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
6468                x = cpu_x.detach().clone().to('mps').requires_grad_()
6469                cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
6470                y = cpu_y.detach().clone().to('mps').requires_grad_()
6471                cpu_z = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
6472                z = cpu_z.detach().clone().to('mps').requires_grad_()
6473
6474            stack = torch.stack([x, y, z], dim=1)
6475            stack_cpu = torch.stack([cpu_x, cpu_y, cpu_z], dim=1)
6476
6477            self.assertEqual(stack, stack_cpu)
6478
6479        helper([2, 8, 4, 5])
6480        helper([2, 8, 4, 5], dtype=torch.float16)
6481        helper([2, 8, 4, 5], dtype=torch.int32)
6482        helper([2, 8, 4, 5], dtype=torch.int64)
6483        helper([2, 8, 4, 5], dtype=torch.bool)
6484        # Empty test - Currently failing! Empty tensor not handled!
6485        # helper([0, 2, 4, 5])
6486
6487    # Test abs
6488    def test_abs(self):
6489        def helper(shape):
6490            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6491            x = cpu_x.detach().clone().to('mps')
6492
6493            abs_result = torch.abs(x)
6494            abs_result_cpu = torch.abs(cpu_x)
6495
6496            self.assertEqual(abs_result, abs_result_cpu)
6497
6498        helper((2, 8, 4, 5))
6499
6500    def test_log(self):
6501        def helper(shape):
6502            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6503            x = cpu_x.detach().clone().to('mps')
6504
6505            log_result = torch.log(x)
6506            log_result_cpu = torch.log(cpu_x)
6507
6508            self.assertEqual(log_result, log_result_cpu)
6509
6510        helper((2, 8, 4, 5))
6511
6512    def test_log_ten(self):
6513        def helper(shape):
6514            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6515            x = cpu_x.detach().clone().to('mps')
6516
6517            log_ten_result = torch.log10(x)
6518            log_ten_result_cpu = torch.log10(cpu_x)
6519
6520            self.assertEqual(log_ten_result, log_ten_result_cpu)
6521
6522        helper((2, 8, 4, 5))
6523
6524    def test_log_two(self):
6525        def helper(shape):
6526            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6527            x = cpu_x.detach().clone().to('mps')
6528
6529            log_two_result = torch.log2(x)
6530            log_two_result_cpu = torch.log2(cpu_x)
6531
6532            self.assertEqual(log_two_result, log_two_result_cpu)
6533
6534        helper((2, 8, 4, 5))
6535
6536    def test_log1p(self):
6537        def helper(shape):
6538            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6539            x = cpu_x.detach().clone().to('mps')
6540
6541            log_result = torch.log1p(x)
6542            log_result_cpu = torch.log1p(cpu_x)
6543
6544            self.assertEqual(log_result, log_result_cpu)
6545
6546        helper((2, 8, 4, 5))
6547
6548    def test_logaddexp(self):
6549        def helper(shape):
6550            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6551            x = cpu_x.detach().clone().to('mps')
6552
6553            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6554            y = cpu_y.detach().clone().to('mps')
6555
6556            log_result = torch.logaddexp(x, y)
6557            log_result_cpu = torch.logaddexp(cpu_x, cpu_y)
6558
6559            self.assertEqual(log_result, log_result_cpu)
6560
6561        helper((2, 8, 4, 5))
6562
6563    def test_logaddexp2(self):
6564        def helper(shape):
6565            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6566            x = cpu_x.detach().clone().to('mps')
6567
6568            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6569            y = cpu_y.detach().clone().to('mps')
6570
6571            log_result = torch.logaddexp2(x, y)
6572            log_result_cpu = torch.logaddexp2(cpu_x, cpu_y)
6573
6574            self.assertEqual(log_result, log_result_cpu)
6575
6576        helper((2, 8, 4, 5))
6577
6578    def test_logsumexp(self):
6579        def helper(shape):
6580            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6581            x = cpu_x.detach().clone().to('mps')
6582
6583            log_result = torch.logsumexp(x, -1)
6584            log_result_cpu = torch.logsumexp(cpu_x, -1)
6585
6586            self.assertEqual(log_result, log_result_cpu)
6587
6588        helper((2, 8, 4, 5))
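        # The numerically stable formulations of these ops are
        # logaddexp(x, y) = m + log1p(exp(-|x - y|)) with m = max(x, y), and
        # logsumexp(x) = m + log(sum(exp(x - m))) with m = max(x), avoiding the
        # overflow a naive log(exp(x) + exp(y)) would hit for large inputs.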
6589
6590    # Test concat forward
6591    def test_cat2(self):
6592
6593        def helper1(shape_x, shape_y, shape_z, shape_w):
6594            cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
6595            x = cpu_x.detach().clone().to('mps')
6596
6597            cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
6598            y = cpu_y.detach().clone().to('mps')
6599
6600            cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
6601            z = cpu_z.detach().clone().to('mps')
6602
6603            cpu_w = torch.randn(shape_w, device='cpu', dtype=torch.float, requires_grad=False)
6604            w = cpu_w.detach().clone().to('mps')
6605
6606            cat = torch.cat([x, y, z, w], dim=1)
6607            cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z, cpu_w], dim=1)
6608
6609            self.assertEqual(cat, cat_cpu)
6610
6611        def helper(shape_x, shape_y, shape_z):
6612            cpu_x = torch.randn(shape_x, device='cpu', dtype=torch.float, requires_grad=False)
6613            x = cpu_x.detach().clone().to('mps')
6614
6615            cpu_y = torch.randn(shape_y, device='cpu', dtype=torch.float, requires_grad=False)
6616            y = cpu_y.detach().clone().to('mps')
6617
6618            cpu_z = torch.randn(shape_z, device='cpu', dtype=torch.float, requires_grad=False)
6619            z = cpu_z.detach().clone().to('mps')
6620
6621            cat = torch.cat([x, y, z], dim=1)
6622            cat_cpu = torch.cat([cpu_x, cpu_y, cpu_z], dim=1)
6623
6624            self.assertEqual(cat, cat_cpu)
6625
6626        helper([2, 8, 4, 5], [2, 10, 4, 5], [2, 6, 4, 5])
6627        helper([2, 2, 4, 5], [2, 3, 4, 5], [2, 5, 4, 5])
6628        # Empty test - Currently failing! Empty tensor not handled!
6629        # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
6630
6631    # Test isnan
6632    def test_isnan(self):
6633        def helper(shape):
6634            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
6635            nan_index = [random.randrange(0, shape[0])]
6636            # make a selected row nan
6637            cpu_x.index_put_(indices=[torch.tensor(nan_index)], values=torch.tensor(float('nan')))
6638            x = cpu_x.detach().clone().to('mps')
6639
6640            isnan_result = torch.isnan(x)
6641            isnan_result_cpu = torch.isnan(cpu_x)
6642
6643            self.assertEqual(isnan_result, isnan_result_cpu)
6644
6645        helper((8, 2, 4, 5))
6646
6647    # Test reciprocal
6648    def test_reciprocal(self):
6649        def helper(shape):
6650            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
6651            x = cpu_x.detach().clone().to('mps').requires_grad_()
6652
6653            reciprocal_result = torch.reciprocal(x)
6654            reciprocal_result_cpu = torch.reciprocal(cpu_x)
6655
6656            cpu_grad = torch.ones_like(reciprocal_result_cpu)
6657            grad = cpu_grad.to('mps')
6658
6659            reciprocal_result.backward(gradient=grad)
6660            reciprocal_result_cpu.backward(gradient=cpu_grad)
6661
6662            self.assertEqual(reciprocal_result, reciprocal_result_cpu)
6663            self.assertEqual(x.grad, cpu_x.grad)
6664
6665        helper((2, 8, 4, 5))
6666
6667    # Test sqrt
6668    def test_sqrt(self):
6669        def helper(shape):
6670            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
6671            x = cpu_x.detach().clone().to('mps').requires_grad_()
6672
6673            sqrt_result = torch.sqrt(x)
6674            sqrt_result_cpu = torch.sqrt(cpu_x)
6675
6676            cpu_grad = torch.ones_like(sqrt_result_cpu)
6677            grad = cpu_grad.to('mps')
6678
6679            sqrt_result.backward(gradient=grad)
6680            sqrt_result_cpu.backward(gradient=cpu_grad)
6681
6682            self.assertEqual(sqrt_result, sqrt_result_cpu)
6683            self.assertEqual(x.grad, cpu_x.grad)
6684
6685        helper((2, 8, 4, 5))
6686
6687    # Test selu, elu, celu
6688    def test_elu(self):
6689        def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):
6690            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
6691            cpu_x = cpu_x.to(memory_format=memory_format).requires_grad_()
6692
6693            x = cpu_x.detach().clone().to('mps').requires_grad_(True)
6694            for activation_func in [torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU()]:
6695                elu_result = activation_func(x)
6696                elu_result_cpu = activation_func(cpu_x)
6697
6698                cpu_grad = torch.randn(elu_result_cpu.shape)
6699                grad = cpu_grad.to('mps')
6700
6701                elu_result.backward(gradient=grad)
6702                elu_result_cpu.backward(gradient=cpu_grad)
6703
6704                self.assertEqual(elu_result, elu_result_cpu)
6705                self.assertEqual(x.grad, cpu_x.grad)
6706
6707        # Test empty shape too
6708        for memory_fromat in [torch.channels_last, torch.contiguous_format]:
6709            for shape in [(2, 8, 4, 5)]:
6710                for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
6711                    helper(shape, alpha, memory_fromat)
6712
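        # For reference: ELU(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise;
        # CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)); SELU is ELU
        # with fixed alpha ~= 1.6733, additionally scaled by ~= 1.0507.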
6713    def test_elu_strided_output(self):
6714        # https://github.com/pytorch/pytorch/issues/124834
6715        elu_input = torch.randn(1, 1024, 500)
6716        alpha = float(1)
6717        inplace = False
6718
6719        elu_input_noncontiguous = elu_input.transpose(1, 2)
6720        self.assertEqual(
6721            F.elu(elu_input_noncontiguous.to('cpu'), alpha, inplace),
6722            F.elu(elu_input_noncontiguous.to('mps'), alpha, inplace)
6723        )
6724
6725    # Test glu
6726    def test_glu(self):
6727        def helper(shape, dim=0):
6728            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
6729            x = cpu_x.detach().clone().to('mps').requires_grad_()
6730
6731            for activation_func in [torch.nn.GLU(dim=dim)]:
6732                glu_result = activation_func(x)
6733                glu_result_cpu = activation_func(cpu_x)
6734
6735                cpu_grad = torch.randn(glu_result_cpu.shape)
6736                grad = cpu_grad.to('mps')
6737
6738                glu_result.backward(gradient=grad)
6739                glu_result_cpu.backward(gradient=cpu_grad)
6740
6741                self.assertEqual(glu_result, glu_result_cpu)
6742                self.assertEqual(x.grad, cpu_x.grad)
6743
6744        for shape in [[4], (2, 4), (2, 8, 4, 6)]:
6745            for dim in range(len(shape)):
6746                helper(shape, dim)
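        # GLU splits its input into halves a, b along `dim` and returns
        # a * sigmoid(b), which is why every size along the tested dims is even.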
6747
6748    # Test softplus
6749    def test_softplus(self):
6750        def helper(shape, beta, threshold, dtype):
6751            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
6752            x = cpu_x.detach().clone().to('mps').requires_grad_()
6753
6754            softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x)
6755            softplus_result_cpu = torch.nn.Softplus(beta=beta, threshold=threshold)(cpu_x)
6756
6757            cpu_grad = torch.randn(softplus_result.shape)
6758            grad = cpu_grad.to('mps')
6759
6760            softplus_result.backward(gradient=grad)
6761            softplus_result_cpu.backward(gradient=cpu_grad)
6762
6763            self.assertEqual(softplus_result, softplus_result_cpu)
6764            self.assertEqual(x.grad, cpu_x.grad)
6765
6766        # Test empty shape too
6767        for shape, beta, threshold, dtype in product(
6768            [(), (2, 3), (10, 10), (2, 3, 4, 5)],
6769            [0.5, 1, 2, 3, 4],
6770            [0.5, 20, 30, 40, 50],
6771            [torch.float16, torch.float32]
6772        ):
6773            helper(shape, beta, threshold, dtype)
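        # Softplus(x) = (1 / beta) * log(1 + exp(beta * x)); when beta * x
        # exceeds `threshold` the implementation reverts to the identity for
        # numerical stability, which the small threshold=0.5 cases exercise.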
6774
6775    # Test silu
6776
6777    def test_silu(self):
6778        def helper(shape):
6779            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
6780            x = cpu_x.detach().clone().to('mps').requires_grad_()
6781
6782            silu_result = torch.nn.SiLU()(x)
6783            silu_result_cpu = torch.nn.SiLU()(cpu_x)
6784
6785            cpu_grad = torch.randn(silu_result_cpu.shape)
6786            grad = cpu_grad.to('mps')
6787
6788            silu_result.backward(gradient=grad)
6789            silu_result_cpu.backward(gradient=cpu_grad)
6790
6791            self.assertEqual(silu_result, silu_result_cpu)
6792            self.assertEqual(x.grad, cpu_x.grad)
6793
6794        # Test empty shape too
6795        for shape in [[], (2, 3), (2, 8, 4, 5)]:
6796            helper(shape)
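        # SiLU(x) = x * sigmoid(x), also known as the swish activation.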
6797
6798    def test_cast_mps_to_cpu(self):
6799        def helper(src_dtype, dst_dtype):
6800            input = torch.rand((1, 3, 128, 128), dtype=src_dtype)
6801            input_cast_mps = input.to('mps')
6802            input_cast_cpu = input_cast_mps.to('cpu', dtype=dst_dtype)
6803
6804            # needs to match the initial Tensor
6805            self.assertEqual(input_cast_cpu, input.to(dtype=dst_dtype))
6806        helper(torch.half, torch.float)
6807        helper(torch.float, torch.half)
6808
6809    def test_cast_mps_to_mps(self):
6810        def helper(src_dtype, dst_dtype):
6811            input_cpu = torch.rand((1, 3, 128, 128), dtype=src_dtype)
6812            input_mps = input_cpu.to('mps')
6813            output_mps = input_mps.to(dtype=dst_dtype)
6814            output_cpu = input_cpu.to(dtype=dst_dtype)
6815            self.assertEqual(output_mps.cpu(), output_cpu)
6816        helper(torch.half, torch.float)
6817        helper(torch.float, torch.half)
6818        helper(torch.half, torch.long)
6819        helper(torch.float, torch.int)
6820
6821    def test_avg_pool2d_count_include_pad(self):
6822        cpu_x = torch.randn((1, 3, 9, 9), device='cpu', dtype=torch.float, requires_grad=True)
6823        x = cpu_x.detach().clone().to('mps').requires_grad_()
6824        pool = torch.nn.AvgPool2d(kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), ceil_mode=True, count_include_pad=True)
6825        ref_y = pool(cpu_x)
6826        y = pool(x)
6827        self.assertEqual(y, ref_y)
6828        cpu_grad = torch.randn(ref_y.shape)
6829        grad = cpu_grad.to('mps')
6830        ref_y.backward(gradient=cpu_grad)
6831        y.backward(gradient=grad)
6832        self.assertEqual(x.grad, cpu_x.grad)
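        # count_include_pad=True includes the zero padding in the averaging
        # denominator and ceil_mode=True rounds the output size up, so CPU and
        # MPS must agree on both the edge windows and any extra ceil-mode window.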
6833
6834    # Test adaptive avg pool2d - when the input size is a multiple of output size
6835    # Not testing for channels last right now
6836    def test_adaptive_avg_pool2d_simple(self):
6837        def helper(input_shape, out_shape, channels_last):
6838            cpu_x = torch.randn(input_shape, device='cpu', dtype=torch.float, requires_grad=True)
6839            if (channels_last):
6840                cpu_x = cpu_x.to(memory_format=torch.channels_last)
6841                cpu_x.retain_grad()
6842            x = cpu_x.detach().clone().to('mps').requires_grad_()
6843
6844            avg_result = torch.nn.AdaptiveAvgPool2d(out_shape)(x)
6845            avg_result_cpu = torch.nn.AdaptiveAvgPool2d(out_shape)(cpu_x)
6846
6847            cpu_grad = torch.randn(avg_result_cpu.shape)
6848            grad = cpu_grad.to('mps')
6849
6850            avg_result.backward(gradient=grad)
6851            avg_result_cpu.backward(gradient=cpu_grad)
6852
6853            self.assertEqual(avg_result, avg_result_cpu)
6854            self.assertEqual(x.grad, cpu_x.grad)
6855
6856        helper((2, 2, 4, 4), (2, 2), False)
6857        helper((2, 2, 9, 9), (3, 3), False)
6858        helper((2, 2, 9, 9), (9, 9), False)
6859        helper((2, 2, 16, 16), (2, 2), False)
6860        helper((2, 2, 16, 16), (2, 16), False)
6861
6862        helper((2, 16, 16), (4, 4), False)
6863
6864        # Output shape larger than input shape
6865
6866        helper((2, 2, 4, 4), (8, 8), False)
6867        helper((2, 2, 2, 2), (4, 4), False)
6868        helper((2, 2, 3, 3), (9, 9), False)
6869        helper((2, 2, 2, 2), (16, 16), False)
6870        helper((2, 2, 2, 16), (16, 16), False)
6871
6872        helper((2, 4, 4), (16, 16), False)
6873
6874        try:
6875            helper((2, 2, 3, 3), (7, 7), False)
6876        except Exception:
6877            pass  # an output size that does not evenly divide the input may raise; that is tolerated here
6878
6879    # Test adaptive max pool2d - when the input size is a multiple of output size
6880    # Not testing for channels last right now
6881    def test_adaptive_max_pool2d_simple(self):
6882        def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
6883            cpu_x = None
6884            if (dtype in [torch.float16, torch.float32]):
6885                cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
6886            else:
6887                cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
6888            if (channels_last):
6889                cpu_x = cpu_x.to(memory_format=torch.channels_last)
6890                cpu_x.retain_grad()
6891            x = cpu_x.detach().clone().to('mps').requires_grad_()
6892
6893            max_result, max_indices = None, None
6894            max_result_cpu, max_indices_cpu = None, None
6895
6896            if (return_indices):
6897                max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
6898                max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
6899            else:
6900                max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
6901                max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
6902
6903            cpu_grad = torch.randn(max_result_cpu.shape)
6904            grad = cpu_grad.to('mps')
6905
6906            max_result.backward(gradient=grad)
6907            max_result_cpu.backward(gradient=cpu_grad)
6908
6909            self.assertEqual(max_result, max_result_cpu)
6910            if (return_indices):
6911                self.assertEqual(max_indices, max_indices_cpu)
6912            self.assertEqual(x.grad, cpu_x.grad)
6913
6914        for dtype in [torch.float32]:
6915            for return_indices in [False, True]:
6916                helper((2, 2, 4, 4), (2, 2), return_indices, dtype)
6917                helper((2, 2, 9, 9), (3, 3), return_indices, dtype)
6918                helper((2, 2, 9, 9), (9, 9), return_indices, dtype)
6919                helper((2, 2, 16, 16), (2, 2), return_indices, dtype)
6920                helper((2, 2, 16, 16), (2, 16), return_indices, dtype)
6921                helper((2, 16, 16), (4, 4), return_indices, dtype)
6922
6923    def test_gelu_simple(self):
6924        def helper(shape, dtype=torch.float, contiguous=True):
6925            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
6926            x = cpu_x.detach().clone().to('mps')
6927
6928            if not contiguous and (0 not in shape and len(shape) >= 2):
6929                # Transposing will make the tensor non-contiguous
6930                cpu_x = cpu_x.transpose(0, 1)
6931                x = x.transpose(0, 1)
6932                assert not x.is_contiguous()
6933
6934            cpu_x.requires_grad_()
6935            x.requires_grad_()
6936
6937            gelu_result = torch.nn.GELU()(x)
6938            # fp16 GELU is not supported on CPU, so compute the reference in float
6939            gelu_result_cpu = torch.nn.GELU()(cpu_x.to(torch.float))
6940
6941            cpu_grad = torch.ones_like(gelu_result_cpu)
6942            grad = cpu_grad.to('mps')
6943
6944            gelu_result.backward(gradient=grad)
6945            gelu_result_cpu.backward(gradient=cpu_grad)
6946
6947            atol = 1e-5 if dtype == torch.float else 1e-2
6948            rtol = 1e-3 if dtype == torch.float else 1e-2
6949            self.assertEqual(gelu_result, gelu_result_cpu.to(dtype), atol=atol, rtol=rtol)
6950
6951            assert x.grad is not None  # Check that the grad is well-populated
6952            self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)
6953
6954        # Test empty shape too
6955        for dtype in [torch.float, torch.half]:
6956            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
6957                for contiguous in [True, False]:
6958                    helper(shape, dtype, contiguous)
6959        # Test that gelu would raise an assert for integral types
6960        for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
6961            self.assertRaises(RuntimeError, lambda: torch.nn.GELU()(torch.randint(100, (2,), dtype=dtype, device="mps")))
6962
6963    def test_mish_simple(self):
6964        def helper(shape, dtype=torch.float, contiguous=True):
6965            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
6966            x = cpu_x.detach().clone().to('mps')
6967
6968            if not contiguous and (0 not in shape and len(shape) >= 2):
6969                # Transposing will make the tensor non-contiguous
6970                cpu_x = cpu_x.transpose(0, 1)
6971                x = x.transpose(0, 1)
6972                assert not x.is_contiguous()
6973
6974            cpu_x.requires_grad_()
6975            x.requires_grad_()
6976
6977            mish_result = torch.nn.Mish()(x)
6978            mish_result_cpu = torch.nn.Mish()(cpu_x)
6979
6980            cpu_grad = torch.ones_like(mish_result_cpu)
6981            grad = cpu_grad.to('mps')
6982
6983            mish_result.backward(gradient=grad)
6984            mish_result_cpu.backward(gradient=cpu_grad)
6985
6986            atol = 1e-5 if dtype == torch.float else 1e-2
6987            rtol = 1e-3 if dtype == torch.float else 1e-2
6988            self.assertEqual(mish_result, mish_result_cpu.to(dtype), atol=atol, rtol=rtol)
6989
6990            assert x.grad is not None  # Check that the grad is well-populated
6991            self.assertEqual(x.grad, cpu_x.grad, atol=atol, rtol=rtol)
6992
6993        # Test empty shape too
6994        for dtype in [torch.float, torch.half]:
6995            for shape in [[], (0,), (0, 3), (4,), (4, 3), (5, 4, 3)]:
6996                for contiguous in [True, False]:
6997                    helper(shape, dtype, contiguous)
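        # Mish(x) = x * tanh(softplus(x)).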
6998
6999    def test_gelu(self):
7000        def _test_gelu(n, m, dtype, contiguous, atol=None, rtol=None):
7001            numpy_dtype = {
7002                torch.bfloat16: torch.float, torch.float: torch.float, torch.double: torch.double
7003            }[dtype]
7004            devices = ['cpu', 'mps']
7006
7007            def _gelu_ref(X):
7008                return X * stats.norm.cdf(X)  # noqa: F821
7009
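            # Note: _gelu_ref is kept for reference only (scipy may be
            # unavailable); the loop below just checks that a strided slice of X
            # survives the dtype/device round-trip unchanged.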
7010            for d in devices:
7011                X = torch.rand(n, m, dtype=dtype, requires_grad=True, device=d)[:, ::2]
7012                res = X
7013                ref = (X.to(numpy_dtype).cpu().detach().numpy())
7014                self.assertEqual(res, ref, rtol=rtol, atol=atol, exact_dtype=False)
7015
7016        for n in [1, 5, 10]:
7017            for m in [1, 5, 10]:
7018                _test_gelu(n, m, torch.float32, True)
7019                _test_gelu(n, m, torch.float32, False)
7020
7021        # Test multi threaded
7022        num_threads = torch.get_num_threads()
7023        torch.set_num_threads(4)
7024        try:
7025            _test_gelu(32, 32, torch.float32, False)
7026        finally:
7027            torch.set_num_threads(num_threads)
7028
7029    def test_gelu_tanh(self):
7030        def helper(shape):
7031            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
7032            x = cpu_x.detach().clone().to('mps')
7033
7034            gelu_tanh_result = torch.nn.functional.gelu(x, approximate='tanh')
7035            gelu_tanh_result_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh')
7036            self.assertEqual(gelu_tanh_result, gelu_tanh_result_cpu)
7037
7038        helper((2, 8, 4, 5))
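        # GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))); with
        # approximate='tanh' it is instead computed as
        # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).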
7039
7040    # Test hardtanh
7041    def test_hardtanh(self):
7042        def helper(shape, min_val, max_val, inplace=False):
7043            cpu_x = None
7044            x = None
7045
7046            if (not inplace):
7047                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7048                x = cpu_x.detach().clone().to('mps').requires_grad_()
7049            else:
7050                cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
7051                x = cpu_x.detach().clone().to('mps')
7052
7053            hardtanh_result = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(x)
7054            hardtanh_result_cpu = torch.nn.Hardtanh(min_val=min_val, max_val=max_val, inplace=inplace)(cpu_x)
7055
7056            self.assertEqual(hardtanh_result, hardtanh_result_cpu)
7057
7058            if (not inplace):
7059                cpu_grad = torch.randn(hardtanh_result_cpu.shape)
7060                grad = cpu_grad.to('mps')
7061                hardtanh_result.backward(gradient=grad)
7062                hardtanh_result_cpu.backward(gradient=cpu_grad)
7063                self.assertEqual(x.grad, cpu_x.grad)
7064
7065        # Test empty shape too
7066        for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
7067            for min_val, max_val in zip([-1, -2, 3], [1, -1, 4]):
7068                helper(shape, min_val, max_val)
7069                helper(shape, min_val, max_val, inplace=True)
7070
7071    def test_hardswish(self):
7072        def helper(shape, inplace=False, requires_grad=True):
7073            m = nn.Hardswish(inplace=inplace)
7074
7075            input_cpu = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=requires_grad)
7076            input_mps = input_cpu.detach().clone().to('mps').requires_grad_(requires_grad)
7077
7078            if inplace and requires_grad:  # check that both raise runtime error
7079                self.assertRaises(RuntimeError, lambda: m(input_cpu))
7080                self.assertRaises(RuntimeError, lambda: m(input_mps))
7081                return
7082
7083            output_cpu = m(input_cpu)
7084            output_mps = m(input_mps)
7085
7086            cpu_grad = torch.ones_like(output_cpu)
7087            mps_grad = cpu_grad.to('mps')
7088
7089            self.assertEqual(output_cpu, output_mps)
7090
7091            if requires_grad:
7092                output_cpu.backward(gradient=cpu_grad)
7093                output_mps.backward(gradient=mps_grad)
7094
7095                self.assertEqual(input_cpu.grad, input_mps.grad)
7096
7097        for shape in [(0, 3), [], (2, 3), (2, 8, 4, 5)]:
7098            helper(shape, inplace=False, requires_grad=False)
7099            helper(shape, inplace=True, requires_grad=False)
7100            helper(shape, inplace=False, requires_grad=True)
7101            helper(shape, inplace=True, requires_grad=True)
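        # Hardswish(x) = x * relu6(x + 3) / 6, a piecewise-polynomial
        # approximation of SiLU.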
7102
7103    def test_transpose_2D(self):
7104        values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
7105        values1 = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
7106        cpu_x = torch.tensor(values, device='cpu')
7107        mps_x = torch.tensor(values, device='mps')
7108        mps_x1 = torch.tensor(values1, device='mps')
7109
7110        cpu_transpose = torch.transpose(cpu_x, 0, 1)
7111        mps_transpose = torch.transpose(mps_x, 0, 1)
7112        self.assertEqual(cpu_transpose, mps_transpose.to('cpu'))
7113
7114    def test_transpose_3D(self):
7115        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
7116        cpu_x = torch.tensor(values, device='cpu')
7117        mps_x = torch.tensor(values, device='mps')
7118
7119        cpu_transpose1 = torch.transpose(cpu_x, 0, 1)
7120        mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu')
7121        self.assertEqual(cpu_transpose1, mps_transpose1)
7122
7123        cpu_transpose2 = torch.transpose(cpu_x, 0, 2)
7124        mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu')
7125        self.assertEqual(cpu_transpose2, mps_transpose2)
7126
7127        cpu_transpose3 = torch.transpose(cpu_x, 1, 2)
7128        mps_transpose3 = torch.transpose(mps_x, 1, 2).to('cpu')
7129        self.assertEqual(cpu_transpose3, mps_transpose3)
7130
7131
7132    def test_transpose_4D(self):
7133        values = [[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]],
7134                  [[[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]], [[19.0, 20.0, 21.0], [22.0, 23.0, 24.0]]]]
7135        cpu_x = torch.tensor(values, device='cpu')
7136        mps_x = torch.tensor(values, device='mps')
7137
7138        cpu_transpose1 = torch.transpose(cpu_x, 0, 1)
7139        mps_transpose1 = torch.transpose(mps_x, 0, 1).to('cpu')
7140        self.assertEqual(cpu_transpose1, mps_transpose1)
7141
7142        cpu_transpose2 = torch.transpose(cpu_x, 0, 2)
7143        mps_transpose2 = torch.transpose(mps_x, 0, 2).to('cpu')
7144        self.assertEqual(cpu_transpose2, mps_transpose2)
7145
7146        cpu_transpose3 = torch.transpose(cpu_x, 0, 3)
7147        mps_transpose3 = torch.transpose(mps_x, 0, 3).to('cpu')
7148        self.assertEqual(cpu_transpose3, mps_transpose3)
7149
7150        cpu_transpose4 = torch.transpose(cpu_x, 3, 1)
7151        mps_transpose4 = torch.transpose(mps_x, 3, 1).to('cpu')
7152        self.assertEqual(cpu_transpose4, mps_transpose4)
7153
7154        cpu_transpose5 = torch.transpose(cpu_x, 3, 2)
7155        mps_transpose5 = torch.transpose(mps_x, 3, 2).to('cpu')
7156        self.assertEqual(cpu_transpose5, mps_transpose5)
7157
7158        cpu_transpose6 = torch.transpose(cpu_x, 1, 2)
7159        mps_transpose6 = torch.transpose(mps_x, 1, 2).to('cpu')
7160        self.assertEqual(cpu_transpose6, mps_transpose6)
7161
7162    # Test sign
7163    def test_sign(self):
7164        def helper(shape):
7165            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7166            x = cpu_x.detach().clone().to('mps').requires_grad_()
7167
7168            sign_result = torch.sign(x)
7169            sign_result_cpu = torch.sign(cpu_x)
7170
7171            cpu_grad = torch.ones_like(sign_result_cpu)
7172            grad = cpu_grad.to('mps')
7173
7174            sign_result.backward(gradient=grad)
7175            sign_result_cpu.backward(gradient=cpu_grad)
7176
7177            self.assertEqual(sign_result, sign_result_cpu)
7178
7179        helper((2, 8, 4, 5))
7180
7181    def test_signbit(self):
7182        def helper(shape, dtype):
7183            cpu_x = torch.randn(shape, device='cpu').to(dtype)
7184            x = cpu_x.clone().to('mps')
7185
7186            signbit_result = torch.signbit(x)
7187            signbit_result_cpu = torch.signbit(cpu_x)
7188
7189            self.assertEqual(signbit_result, signbit_result_cpu)
7190
7191        helper((2, 8, 4, 5), torch.int)
7192        helper((2, 8, 4, 5), torch.float)
7193        helper((2, 8, 4, 5), torch.int64)
7194
7195    # Test neg
7196    def test_neg(self):
7197        def helper(shape):
7198            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7199            x = cpu_x.detach().clone().to('mps').requires_grad_()
7200
7201            neg_result = torch.neg(x)
7202            neg_result_cpu = torch.neg(cpu_x)
7203
7204            cpu_grad = torch.ones_like(neg_result_cpu)
7205            grad = cpu_grad.to('mps')
7206
7207            neg_result.backward(gradient=grad)
7208            neg_result_cpu.backward(gradient=cpu_grad)
7209
7210            self.assertEqual(neg_result, neg_result_cpu)
7211
7212        helper((2, 8, 4, 5))
7213
7214    def test_neg_strided_input(self):
7215        # See https://github.com/pytorch/pytorch/issues/98074#issuecomment-1496088337
7216        x = torch.arange(18.0, device='mps').reshape(2, 3, 3)
7217        y = x.permute(1, 0, 2)[..., 1]
7218        z = y + y.neg()
7219        self.assertEqual(z.abs().max().item(), 0.0)
7220
7221    # Test index add
7222    def test_index_add(self):
7223        def helper(shape, dim, index, source_shape, alpha, x_dtype=torch.float32, idx_dtype=torch.int32):
7224            cpu_x = torch.randn(shape, device='cpu', dtype=x_dtype, requires_grad=False)
7225            x = cpu_x.detach().clone().to('mps')
7226
7227            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
7228            idx = cpu_idx.detach().clone().to('mps')
7229
7230            cpu_source = torch.randn(source_shape, device='cpu', dtype=x_dtype, requires_grad=False)
7231            source = cpu_source.detach().clone().to('mps')
7232
7233            idx_result = torch.index_add(x, dim=dim, index=idx, source=source, alpha=alpha)
7234            idx_result_cpu = torch.index_add(cpu_x, dim=dim, index=cpu_idx, source=cpu_source, alpha=alpha)
7235            self.assertEqual(idx_result, idx_result_cpu)
7236
7237        helper((2, 8, 4, 5), 0, [0, 1, 0], (3, 8, 4, 5), 5)
7238        helper((8, 8, 4, 5), 0, [7], (1, 8, 4, 5), 6.0)
7239        helper((2, 8, 4, 5), 1, [0, 3, 7], (2, 3, 4, 5), 5)
7240        helper((2, 8, 4, 5), 2, [3, 0], (2, 8, 2, 5), 3.0)
7241        helper((2, 8, 4, 5), 3, [2, 3, 0], (2, 8, 4, 3), 4)
7242        helper((2, 3, 3), -1, [1, 2], (2, 3, 2), 6.0)
7243        # test result dim=1
7244        helper((2,), 0, [1], (1,), 6.0)
7245        helper(2, 0, 1, 1, 6)
7246        # test float16
7247        helper((2,), 0, [1], (1,), 6.0, x_dtype=torch.float16)
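        # index_add accumulates x[..., index[i], ...] += alpha * source[..., i, ...]
        # along `dim`; duplicate indices (e.g. [0, 1, 0] above) add up.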
7248
7249    def test_index_64bit(self):
7250        """ Test that index operations work for 4Gb+ tensors """
7251        if product_version < 14.0:
7252            raise unittest.SkipTest("Sonoma is needed for large tensors, see https://github.com/pytorch/pytorch/issues/84039")
7253        # Cleanup memory
7254        gc.collect()
7255        torch.mps.empty_cache()
7256        # Check that index operations work for 4+GB tensors
7257        x = torch.rand(16000, 67120, device="mps")
7258        self.assertGreater(x.element_size() * x.numel(), 2**32)
7259        idx = torch.arange(0, 2, device="mps")
7260        x_sampled = x[:, idx]
7261        self.assertEqual(x[:, 0], x_sampled[:, 0])
7262        # Reclaim memory after running the tests
7263        del x
7264        gc.collect()
7265        torch.mps.empty_cache()
7266
7267    def test_mm_large(self):
7268        """ Test that MM works for matrices with index larger than 32K """
7269        x = torch.rand(10, 1, device="mps")
7270        y = torch.rand(1, 32769, device="mps")
7271        # This used to crash with:
7272        # error: subRange.start (24576) is not less than length of dimension[0] (16384)
7273        # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
7274        self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)
7275
7276        def compare_mm(m, n, k, dtype=torch.float):
7277            x = torch.rand(m, n, device="mps", dtype=dtype)
7278            y = torch.rand(n, k, device="mps", dtype=dtype)
7279            z = torch.mm(x, y).cpu()
7280            z_cpu = torch.mm(x.cpu(), y.cpu())
7281            self.assertEqual(z, z_cpu)
7282
7283        # Used to produce incorrect results with MPS on M1 running MacOS 14.3, but correct with Metal
7284        compare_mm(1024, 1, 32769)
7285        # one more time, but with dimensions inverted
7286        # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984
7287        compare_mm(32769, 1, 1025)
7288
7289        if product_version >= 14.0:
7290            # Test bfloat16 mm
7291            compare_mm(1024, 1, 32769, torch.bfloat16)
7292
7293    @unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test")
7294    @unittest.skipIf(product_version < 14.0, "Can't allocate 4Gb tensor on MacOS 13")
7295    def test_copy_large(self):
7296        """ Test that copy of 4Gb+ tensors works """
7297        x = torch.ones((2**30 + 11,), dtype=torch.float32)
7298        y = x.to(device="mps")
7299        self.assertTrue(torch.all(y == torch.tensor(1.0, device="mps")))
7300        del y
7301        del x
7302
7303    # Test flip
7304    def test_flip(self):
7305        def helper(shape, dims):
7306            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
7307            x = cpu_x.detach().clone().to('mps')
7308
7309            flip_result = torch.flip(x, dims=dims)
7310            flip_result_cpu = torch.flip(cpu_x, dims=dims)
7311
7312            self.assertEqual(flip_result, flip_result_cpu)
7313
7314        helper((2, 8, 4, 5), [0])
7315        helper((8, 8, 4, 5), [0, 1])
7316        helper((2, 8, 4, 5), (0, 1, 2, 3))
7317        helper((2, 3, 3), (-1,))
7318        # empty dims
7319        helper((2, 8, 4, 5), [])
7320        # input.numel() == 1
7321        helper((1,), (0,))
7322        # input.numel() == 0
7323        helper((0,), (0,))
7324        # none of dims that needs to be flipped
7325        helper((1, 3), [0])
7326
7327    # Test index select
7328    def test_index_select(self):
7329        def helper(shape, dim, index, idx_dtype=torch.int32):
7330            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
7331            x = cpu_x.detach().clone().to('mps')
7332
7333            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
7334            idx = cpu_idx.detach().clone().to('mps')
7335
7336            idx_result = torch.index_select(x, dim=dim, index=idx)
7337            idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)
7338
7339            self.assertEqual(idx_result, idx_result_cpu)
7340
7341        helper((2, 8, 4, 5), 0, [1])
7342        helper((8, 8, 4, 5), 0, [0, 3, 2, 7, 6])
7343        helper((2, 8, 4, 5), 1, [0, 3, 2, 7, 6])
7344        helper((2, 8, 4, 5), 2, [3, 0, 1])
7345        helper((2, 8, 4, 5), 3, [2, 3, 0])
7346        helper((2, 3, 3), -1, [1, 2])
7347        helper((), 0, [0])
7348        helper((5,), 0, [])
7349
7350    def test_index_select_scalar(self):
7351        def helper(value, dim, index, idx_dtype=torch.int32):
7352            cpu_x = torch.tensor(value, device='cpu', dtype=torch.float, requires_grad=False)
7353            x = cpu_x.detach().clone().to('mps')
7354
7355            cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
7356            idx = cpu_idx.detach().clone().to('mps')
7357
7358            idx_result = torch.index_select(x, dim=dim, index=idx)
7359            idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)
7360
7361            self.assertEqual(idx_result, idx_result_cpu)
7362
7363        helper(22, 0, [0])
7364        with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"):
7365            helper(22, 0, [])
7366
7367    def test_embedding_dense_backward(self):
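        # Build the same max_norm embedding computation on MPS and CPU and compare the gradients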
7368        def helper(n, d, m, idx):
7369            embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps')
7370            embedding_weight = embeddingMPS.weight.detach().cpu()
7371            W_MPS = torch.randn((m, d), requires_grad=True, device='mps')
7372            idx_MPS = torch.tensor(idx, device='mps')
7373            a_MPS = embeddingMPS.weight.clone() @ W_MPS.t()  # weight must be cloned for this to be differentiable
7374            a_MPS.retain_grad()
7375            b_MPS = embeddingMPS(idx_MPS) @ W_MPS.t()  # modifies weight in-place
7376            b_MPS.retain_grad()
7377            out_MPS = (a_MPS.unsqueeze(0) + b_MPS)
7378            loss_MPS = out_MPS.sigmoid().prod()
7379            loss_MPS.backward()
7380
7381            embeddingCPU = nn.Embedding(n, d, max_norm=True, _weight=embedding_weight)
7382            W_CPU = W_MPS.to('cpu')
7383            idx_CPU = torch.tensor(idx)
7384            a_CPU = embeddingCPU.weight.clone() @ W_CPU.t()  # weight must be cloned for this to be differentiable
7385            a_CPU.retain_grad()
7386            b_CPU = embeddingCPU(idx_CPU) @ W_CPU.t()  # modifies weight in-place
7387            b_CPU.retain_grad()
7388            out_CPU = (a_CPU.unsqueeze(0) + b_CPU)
7389            loss_CPU = out_CPU.sigmoid().prod()
7390            loss_CPU.backward()
7391
7392            self.assertEqual(b_CPU.grad, b_MPS.grad)
7393            self.assertEqual(a_CPU.grad, a_MPS.grad)
7394
7395        helper(3, 5, 7, [0, 1, 2])
7396        helper(3, 6, 7, [0, 1, 2])  # verify that changes in shape do not cause cached graph lookup problems
7397        helper(3, 5, 7, 2)  # test scalar index
7398
7399    # Test pytorch gather
7400    def test_gather(self):
7401        def helper(shape, dim, idx_shape, idx_dtype=torch.int64):
7402            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7403            x = cpu_x.detach().clone().to('mps').requires_grad_()
7404
7405            # Indices should be taken from the range of the axis along which gathering is done
7406            idx_np = np.random.randint(0, shape[dim], idx_shape)
7407
7408            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
7409            idx = cpu_idx.detach().clone().to('mps')
7410
7411            gather_result = torch.gather(x, dim=dim, index=idx)
7412            gather_result_cpu = torch.gather(cpu_x, dim=dim, index=cpu_idx)
7413
7414            cpu_grad = torch.randn(idx_shape, device='cpu', dtype=torch.float)
7415            grad = cpu_grad.to('mps')
7416            gather_result.backward(gradient=grad)
7417            gather_result_cpu.backward(gradient=cpu_grad)
7418
7419            self.assertEqual(gather_result, gather_result_cpu)
7420            self.assertEqual(cpu_x.grad, x.grad)
7421
7422        helper((6, 3, 3), 0, (3, 3, 3))
7423        helper((2, 3, 3, 3), 0, (10, 3, 3, 3))
7424        helper((2, 8, 4, 5), 0, (10, 8, 4, 5))
7425        helper((2, 8, 4, 5), 0, (10, 6, 3, 2))
7426        helper((8, 8, 4, 5), 0, (6, 8, 4, 5))
7427        helper((8, 8, 4, 5), 0, (6, 7, 2, 3))
7428        helper((2, 8, 4, 5), 1, (2, 5, 3, 4))
7429        helper((2, 8, 4, 5), 2, (1, 8, 10, 3))
7430        helper((2, 8, 4, 5), 3, (2, 5, 3, 12))
7431
7432    # Test pytorch gather for scalar input
7433    def test_gather_scalar(self):
7434        idx_dtype = torch.int64
7435        cpu_x = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
7436        x = cpu_x.detach().clone().to('mps').requires_grad_()
7437
7438        idx_np = [0]
7439
7440        cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
7441        idx = cpu_idx.detach().clone().to('mps')
7442
7443        gather_result = torch.gather(x, dim=0, index=idx)
7444        gather_result_cpu = torch.gather(cpu_x, dim=0, index=cpu_idx)
7445
7446        cpu_grad = torch.randn([1], device='cpu', dtype=torch.float)
7447        grad = cpu_grad.to('mps')
7448        gather_result.backward(gradient=grad)
7449        gather_result_cpu.backward(gradient=cpu_grad)
7450
7451        self.assertEqual(gather_result, gather_result_cpu)
7452        self.assertEqual(cpu_x.grad, x.grad)
7453
7454    # Test pytorch scatter_add and scatter
7455    def test_scatter_add(self):
7456        def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True):
7457            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7458            x = cpu_x.detach().clone().to('mps').requires_grad_()
7459
7460            cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
7461            src = cpu_src.detach().clone().to('mps').requires_grad_()
7462
7463            # Indices should be taken from the range of the axis along which scattering is done
7464            idx_np = None
7465            if (do_add):
7466                idx_np = np.random.randint(0, shape[dim], idx_shape)
7467            else:
7468                idx_np = np.array([[0, 1, 2],
7469                                   [1, 2, 3],
7470                                   [2, 3, 4],
7471                                   [3, 4, 5],
7472                                   [4, 5, 6]])
7473
7474            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
7475            idx = cpu_idx.detach().clone().to('mps')
7476
7477            scatter_result = None
7478            scatter_result_cpu = None
7479
7480            if (do_add):
7481                scatter_result = torch.scatter_add(x, dim=dim, index=idx, src=src)
7482                scatter_result_cpu = torch.scatter_add(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)
7483            else:
7484                scatter_result = torch.scatter(x, dim=dim, index=idx, src=src)
7485                scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src)
7486
7487            cpu_grad = None
7488            grad = None
7489
7490            if (idx_shape == src_shape):
7491                cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
7492                grad = cpu_grad.to('mps')
7493                scatter_result.backward(gradient=grad)
7494                scatter_result_cpu.backward(gradient=cpu_grad)
7495
7496            self.assertEqual(scatter_result, scatter_result_cpu)
7497            if (idx_shape == src_shape):
7498                self.assertEqual(cpu_x.grad, x.grad)
7499                self.assertEqual(cpu_src.grad, src.grad)
7500
7501        helper((2, 3), 0, (5, 3), (5, 3))
7502        helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5))
7503        helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5))
7504        helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2))
7505        helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2))
7506        helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5))
7507
7508        helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5))
7509        helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2))
7510        helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3))
7511        helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3))
7512
7513        helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8))
7514        helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6))
7515        helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6))
7516
7517        # Test scatter src
7518        helper((8, 3), 0, (5, 3), (5, 3), do_add=False)
7519        helper((10, 3), 0, (5, 3), (5, 8), do_add=False)
7520
7521    # Test pytorch scatter_add and scatter for scalar input
7522    def test_scatter_add_scalar(self):
7523        def helper(idx_dtype=torch.int64, do_add=True):
7524            cpu_x = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True)
7525            x = cpu_x.detach().clone().to('mps').requires_grad_()
7526
7527            cpu_src = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
7528            src = cpu_src.detach().clone().to('mps').requires_grad_()
7529
7530            # Indices should be taken from the range of the axis along which scattering is done
7531            idx_np = [0]
7532
7533            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
7534            idx = cpu_idx.detach().clone().to('mps')
7535
7536            scatter_result = None
7537            scatter_result_cpu = None
7538
7539            if (do_add):
7540                scatter_result = torch.scatter_add(x, dim=0, index=idx, src=src)
7541                scatter_result_cpu = torch.scatter_add(cpu_x, dim=0, index=cpu_idx, src=cpu_src)
7542            else:
7543                scatter_result = torch.scatter(x, dim=0, index=idx, src=src)
7544                scatter_result_cpu = torch.scatter(cpu_x, dim=0, index=cpu_idx, src=cpu_src)
7545
7546            cpu_grad = None
7547            grad = None
7548
7549            cpu_grad = torch.tensor(1.2, device='cpu', dtype=torch.float)
7550            grad = cpu_grad.to('mps')
7551            scatter_result.backward(gradient=grad)
7552            scatter_result_cpu.backward(gradient=cpu_grad)
7553
7554            self.assertEqual(scatter_result, scatter_result_cpu)
7555            self.assertEqual(cpu_x.grad, x.grad)
7556            self.assertEqual(cpu_src.grad, src.grad)
7557
7558        helper()
7559        helper(do_add=False)
7560
7561    # Test pytorch scatter_reduce
7562    def test_scatter_reduce(self):
7563        def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"):
7564            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7565            x = cpu_x.detach().clone().to('mps').requires_grad_()
7566
7567            cpu_src = torch.randn(src_shape, device='cpu', dtype=torch.float, requires_grad=True)
7568            src = cpu_src.detach().clone().to('mps').requires_grad_()
7569
7570            # Indices should be taken from the range of the axis along which scattering is done
7571            idx_np = np.random.randint(0, shape[dim], idx_shape)
7572
7573            cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
7574            idx = cpu_idx.detach().clone().to('mps')
7575
7576            scatter_result = torch.scatter(x, dim=dim, index=idx, src=src, reduce=reduce_str)
7577            scatter_result_cpu = torch.scatter(cpu_x, dim=dim, index=cpu_idx, src=cpu_src, reduce=reduce_str)
7578
7579            self.assertEqual(scatter_result, scatter_result_cpu)
7580
7581        # TODO: also cover the reduce modes "sum", "prod", "amax" and "amin"
7582        for reduce_type in ["add", "multiply"]:
7583            helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce_type)
7584            helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
7585            helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
7586            helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
7587            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
7588            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce_type)
7589
7590            helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce_type)
7591            helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce_type)
7592            helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce_type)
7593            helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce_type)
7594
7595            helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce_type)
7596            helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce_type)
7597            helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce_type)
7598
7599    def test_is_nonzero(self):
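        # torch.is_nonzero on MPS must match the CPU semantics for single-element tensors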
7600        self.assertFalse(torch.is_nonzero(torch.tensor([0.]).to('mps')))
7601        self.assertTrue(torch.is_nonzero(torch.tensor([1.5]).to('mps')))
7602        self.assertFalse(torch.is_nonzero(torch.tensor([False]).to('mps')))
7603        self.assertTrue(torch.is_nonzero(torch.tensor([3]).to('mps')))
7604
7605    # Test triu
7606    def test_triu(self):
7607        def helper(shape, diag=0):
7608            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7609            x = cpu_x.detach().clone().to('mps').requires_grad_()
7610
7611            triu_result = torch.triu(x, diag)
7612            triu_result_cpu = torch.triu(cpu_x, diag)
7613
7614            cpu_grad = torch.randn(triu_result_cpu.shape)
7615            grad = cpu_grad.to('mps')
7616
7617            triu_result.backward(gradient=grad)
7618            triu_result_cpu.backward(gradient=cpu_grad)
7619
7620            self.assertEqual(triu_result, triu_result_cpu)
7621            self.assertEqual(x.grad, cpu_x.grad)
7622
7623        helper((2, 8, 4, 5))
7624        helper((2, 8, 4, 5), diag=1)
7625        helper((2, 8, 4, 5), diag=2)
7626        helper((2, 8, 4, 5), diag=3)
7627        helper((2, 8, 4, 5), diag=-1)
7628        helper((2, 8, 4, 5), diag=-2)
7629        helper((2, 8, 4, 5), diag=-3)
7630
7631    # Test inverse
7632    def test_inverse(self):
7633        def helper(n):
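            # invert the same random n-by-n matrix on both devices and compare the results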
7634            cpu_input = torch.randn(n, n, device='cpu')
7635            mps_input = cpu_input.to('mps')
7636
7637            cpu_result = torch.linalg.inv(cpu_input)
7638            mps_result = torch.linalg.inv(mps_input)
7639            self.assertEqual(cpu_result, mps_result)
7640
7641        helper(2)
7642        helper(6)
7643        helper(3)
7644        helper(8)
7645
7646    # Test tril
7647    def test_tril(self):
7648        def helper(shape, diag=0):
7649            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7650            x = cpu_x.detach().clone().to('mps').requires_grad_()
7651
7652            tril_result = torch.tril(x, diag)
7653            tril_result_cpu = torch.tril(cpu_x, diag)
7654
7655            cpu_grad = torch.randn(tril_result_cpu.shape)
7656            grad = cpu_grad.to('mps')
7657
7658            tril_result.backward(gradient=grad)
7659            tril_result_cpu.backward(gradient=cpu_grad)
7660
7661            self.assertEqual(tril_result, tril_result_cpu)
7662            self.assertEqual(x.grad, cpu_x.grad)
7663
7664        helper((2, 8, 4, 5))
7665        helper((2, 8, 4, 5), diag=1)
7666        helper((2, 8, 4, 5), diag=2)
7667        helper((2, 8, 4, 5), diag=3)
7668        helper((2, 8, 4, 5), diag=-1)
7669        helper((2, 8, 4, 5), diag=-2)
7670        helper((2, 8, 4, 5), diag=-3)
7671
7672    # Test eye
7673    def test_eye(self):
7674        def helper(n, m, dtype):
7675            cpu_result = None
7676            result = None
7677
7678            if (n == m):
7679                cpu_result = torch.eye(n, dtype=dtype, device='cpu')
7680                result = torch.eye(n, dtype=dtype, device='mps')
7681            else:
7682                cpu_result = torch.eye(n, m, dtype=dtype, device='cpu')
7683                result = torch.eye(n, m, dtype=dtype, device='mps')
7684
7685            self.assertEqual(result, cpu_result)
7686
7687        for dtype in [torch.bool, torch.float16, torch.float32, torch.uint8, torch.int16, torch.int32, torch.int64]:
7688            helper(2, 2, dtype)
7689            helper(2, 3, dtype)
7690            helper(0, 2, dtype)
7691            helper(0, 0, dtype)
7692            helper(3, 8, dtype)
7693            helper(8, 3, dtype)
7694
7695    # Test diag
7696    def test_diag(self):
7697        def helper(shape, diag=0):
7698            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7699            x = cpu_x.detach().clone().to('mps').requires_grad_()
7700
7701            diag_result = torch.diag(x, diag)
7702            diag_result_cpu = torch.diag(cpu_x, diag)
7703
7704            # cpu_grad = torch.randn(diag_result_cpu.shape)
7705            # grad = cpu_grad.to('mps')
7706
7707            # diag_result.backward(gradient=grad)
7708            # diag_result_cpu.backward(gradient=cpu_grad)
7709
7710            self.assertEqual(diag_result, diag_result_cpu)
7711            # self.assertEqual(x.grad, cpu_x.grad)
7712
7713        for shape in [(5, 5), (5, 6), (6, 5), (5,), (6,)]:
7714            for diag in [0, 1, 2, 3, 4, -1, -2, -3, -4]:
7715                helper(shape, diag=diag)
7716
7717    # Test linspace
7718    def test_linspace(self):
7719        def helper(start, end, steps, dtype=torch.float32):
7720            cpu_result = torch.tensor(np.linspace(start, end, steps), dtype=dtype)
7721            result = torch.linspace(start, end, steps, dtype=dtype, device='mps')
7722            self.assertEqual(cpu_result, result)
7723
7724        for dtype in [torch.float32, torch.int32, torch.uint8, torch.int64]:
7725            helper(2, 5, 10, dtype)
7726            helper(2, 2, 10, dtype)
7727            helper(5, 2, 10, dtype)
7728            helper(2, 2, 0, dtype)
7729
7730    # Test arange
7731    def test_arange(self):
7732        self.assertEqual(np.arange(10), torch.arange(10, device='mps'))
7733        self.assertEqual(np.arange(7, 1, -1), torch.arange(7, 1, -1, device='mps'))
7734        self.assertEqual(np.arange(1, 2, .3, dtype=np.float32), torch.arange(1, 2, .3, device='mps'))
7735        self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(6.3, device='mps'))
7736
7737    def test_arange_empty(self):
7738        out_mps = torch.tensor([], device="mps")
7739        out_cpu = torch.tensor([], device="cpu")
7740
7741        y_mps = torch.arange(0, 0, 1, out=out_mps)
7742        y_cpu = torch.arange(0, 0, 1, out=out_cpu)
7743        self.assertEqual(y_mps, y_cpu)
7744
7745    # Test range
7746    def test_range(self):
7747        self.assertEqual(np.arange(11, dtype=np.float32), torch.range(0, 10, device='mps'))
7748        self.assertEqual(np.arange(7, 0, -1, dtype=np.float32), torch.range(7, 1, -1, device='mps'))
7749        self.assertEqual(np.array([1.0000, 1.3000, 1.6000, 1.9000], dtype=np.float32), torch.range(1, 2, .3, device='mps'))
7750        self.assertEqual(np.arange(6.3, dtype=np.float32), torch.arange(0, 6.3, device='mps'))
7751
7752    # Test softmax
7753    def test_softmax(self):
7754        def helper(shape, dim, channels_last=False):
7755            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
7756            if (channels_last):
7757                cpu_x = cpu_x.to(memory_format=torch.channels_last)
7758                cpu_x.retain_grad()
7759            x = cpu_x.detach().clone().to('mps').requires_grad_()
7760
7761            softmax_result = torch.nn.functional.softmax(x, dim=dim)
7762            softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim)
7763
7764            # Currently NOT testing backward for channels last
7765            cpu_grad = None
7766            grad = None
7767
7768            if (not channels_last):
7769                cpu_grad = torch.randn(shape, device='cpu', dtype=torch.float)
7770                grad = cpu_grad.to('mps')
7771
7772                softmax_result.backward(gradient=grad)
7773                softmax_result_cpu.backward(gradient=cpu_grad)
7774
7775            self.assertEqual(softmax_result, softmax_result_cpu)
7776            if (not channels_last):
7777                self.assertEqual(x.grad, cpu_x.grad)
7778
7779        def helper2(dim):
7780            cpu_x = torch.tensor(1.23, device='cpu', dtype=torch.float, requires_grad=True)
7781            x = cpu_x.detach().clone().to('mps').requires_grad_()
7782
7783            softmax_result = torch.nn.functional.softmax(x, dim=dim)
7784            softmax_result_cpu = torch.nn.functional.softmax(cpu_x, dim=dim)
7785
7786            cpu_grad = torch.tensor(2.34, device='cpu', dtype=torch.float)
7787            grad = cpu_grad.to('mps')
7788
7789            softmax_result.backward(gradient=grad)
7790            softmax_result_cpu.backward(gradient=cpu_grad)
7791
7792            self.assertEqual(softmax_result, softmax_result_cpu)
7793            self.assertEqual(x.grad, cpu_x.grad)
7794
7795        helper2(0)
7796
7797        for channels_last in [False]:
7798            for shape in [(2, 4, 8, 5), (3, 4, 6, 7, 2)]:
7799                if (len(shape) != 4 and channels_last):
7800                    continue
7801                for dim in [0, 1, 2, 3, -1, -2, -3]:
7802                    helper(shape, dim, channels_last)
7803
7804    def test_nan_to_num(self):
7805        inputCPU = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.14])
7806        inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
7807        outputCPU = torch.nan_to_num(inputCPU, nan=2.0, posinf=1.0, neginf=-1.0)
7808        outputMPS = torch.nan_to_num(inputMPS, nan=2.0, posinf=1.0, neginf=-1.0)
7809        self.assertEqual(outputMPS, outputCPU)
7810
7811    # Test where
7812    def test_where(self):
7813        def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float):
7814
7815            cpu_cond = torch.randint(2, shape, device='cpu', dtype=cond_dtype, requires_grad=False)
7816            cond = cpu_cond.detach().clone().to('mps')
7817
7818            cpu_x = torch.randn(x_shape, device='cpu', dtype=x_dtype, requires_grad=True)
7819            x = cpu_x.detach().clone().to('mps').requires_grad_()
7820
7821            cpu_y = torch.randn(y_shape, device='cpu', dtype=x_dtype, requires_grad=True)
7822            y = cpu_y.detach().clone().to('mps').requires_grad_()
7823
7824            cpu_out = torch.where(cpu_cond, cpu_x, cpu_y)
7825            out = torch.where(cond, x, y)
7826
7827            cpu_grad = torch.randn(cpu_out.shape)
7828            grad = cpu_grad.to('mps')
7829
7830            cpu_out.backward(gradient=cpu_grad)
7831            out.backward(gradient=grad)
7832
7833            self.assertEqual(out, cpu_out)
7834            self.assertEqual(x.grad, cpu_x.grad)
7835            self.assertEqual(y.grad, cpu_y.grad)
7836
7837        for shape in ([(0, 3), [], (2, 3), (9,)]):
7838            helper(shape, shape, shape)
7839
7840        helper((2, 3, 1), (2, 3, 4), (2, 1, 4))
7841        helper((2, 1, 1), (2, 3, 4), (1, 3, 4))
7842        helper((1, 1, 1), (1, 1, 4), (2, 3, 1))
7843        helper([], (1, 1, 4), (2, 3, 1))
7844        helper([], (2, 3, 4), [])
7845        helper((5, 2, 3), (2, 3), (2, 3))
7846        helper((2, 3), (5, 2, 3), (2, 3))
7847        helper((2, 3), (2, 3), (5, 2, 3))
7848        helper((2, 3), (5, 2, 3), (6, 5, 2, 3))
7849        # Test that the output is correctly resized
7850        # TODO: Remove me when out OpInfo testing is enabled on MPS
7851        output = torch.tensor(0.0, device="mps")
7852        cond = torch.randint(2, (3, 3), dtype=torch.bool, device="mps")
7853        inp = torch.rand(3, 3, device="mps")
7854        other = torch.rand(3, 3, device="mps")
7855        out = torch.where(cond, inp, other, out=output)
7856        self.assertEqual(id(out), id(output))
7857        self.assertEqual(out.shape, (3, 3))
7858
7859    # Test normal
7860    def test_normal(self):
7861        def helper(shape, mean=0.0, std=1.0):
7862            mps_out = torch.normal(mean, std, shape, device='mps')
7863
7864            mean_array = np.ones(shape)
7865            mean_array *= mean
7866            cpu_mean_tensor = torch.tensor(mean_array, device='cpu', dtype=torch.float, requires_grad=False)
7867            mean_tensor = cpu_mean_tensor.detach().clone().to('mps')
7868
7869            std_array = np.ones(shape)
7870            std_array *= std
7871            cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False)
7872            std_tensor = cpu_std_tensor.detach().clone().to('mps')
7873
7874            # test out
7875            mps_out = torch.zeros(shape, device='mps')
7876            torch.normal(mean_tensor, std, out=mps_out)
7877
7878            mps_out = torch.zeros(shape, device='mps')
7879            torch.normal(mean, std_tensor, out=mps_out)
7880
7881            mps_out = torch.zeros(shape, device='mps')
7882            torch.normal(mean_tensor, std_tensor, out=mps_out)
7883
7884            # test without out
7885            mps_out = torch.normal(mean_tensor, std)
7886            self.assertEqual(mps_out.size(), mean_tensor.size())
7887
7888            mps_out = torch.normal(mean, std_tensor)
7889            self.assertEqual(mps_out.size(), std_tensor.size())
7890
7891            inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size())
7892            mps_out = torch.normal(mean_tensor, std_tensor)
7893            self.assertEqual(mps_out.size(), inferred_shape)
7894
7895        helper((2, 3, 4, 5, 6))
7896        helper((100, 100), 2.5, 1.2)
7897
7898    def test_bernoulli(self):
7899        shape = (10, 10)
7900        all_ones = torch.ones(shape, device='mps')
7901        all_zeros = torch.zeros(shape, device='mps')
7902
7903        prob_tensor = all_ones * 0.5
7904        # probability of drawing "1" is 0.5
7905        mps_out = torch.bernoulli(prob_tensor)
7906        # We can't reliably check the mean and std.
7907        # Just make sure we don't return constant values
7908        self.assertNotEqual(mps_out.to('cpu').mean(), 0.)
7909        self.assertNotEqual(mps_out.to('cpu').std() ** 2, 0.)
7910
7911        # probability of drawing "1" is 0
7912        mps_out = torch.bernoulli(all_zeros)
7913        self.assertEqual(mps_out, all_zeros)
7914
7915        # probability of drawing "1" is 1
7916        mps_out = torch.bernoulli(all_ones)
7917        self.assertEqual(mps_out, all_ones)
7918
7919        # Check it works for different dtypes
7920        for dtype in [torch.float16, torch.int8, torch.int16, torch.int32, torch.int64]:
7921            mps_out = torch.zeros(shape, device='mps', dtype=dtype).bernoulli(0.5)
7922            # Check that output is not all zeros or ones
7923            if product_version > 13.0:
7924                uniq = mps_out.unique()
7925                self.assertEqual(uniq, torch.arange(2, device='mps', dtype=dtype))
7926            else:
7927                self.assertEqual(mps_out.min().item(), 0.)
7928                self.assertEqual(mps_out.max().item(), 1.)
7929
7930    def test_mps_generator(self):
7931        # explicit manual seeding by creating an MPS Generator
7932        g_mps = torch.Generator(device='mps')
7933        g_mps.manual_seed(999)
7934        mps_x = torch.randn(5, device='mps', generator=g_mps)
7935        g_mps.manual_seed(999)
7936        # generate random numbers with offset `0`
7937        mps_y = torch.randn(5, device='mps', generator=g_mps)
7938        # seed values were the same, so the random tensor contents should match
7939        self.assertEqual(mps_x, mps_y)
7940        # save generator's state (offset = 1) to restore it later
7941        g_state = g_mps.get_state()
7942
7943        # generate random numbers with offset `1`
7944        mps_x = torch.randn(5, device='mps', generator=g_mps)
7945        # in this case, the random results must differ from the last generated random results
7946        self.assertNotEqual(mps_x, mps_y)
7947
7948        # mps_x was produced by g_state; we use it as our reference mps_y.
7949        mps_y = mps_x
7950
7951        # restore the previously saved state, and the results should match again
7952        g_mps.set_state(g_state)
7953        mps_x = torch.randn(5, device='mps', generator=g_mps)
7954        self.assertEqual(mps_x, mps_y)
7955
7956    @serialTest()
7957    def test_default_mps_generator(self):
7958        # manual seeding on the "default" MPS generator using
7959        # the global torch.manual_seed()
7960        torch.manual_seed(230)
7961        mps_x = torch.randn(5, device='mps')
7962        # manual seeding using torch.mps.manual_seed()
7963        # which should set the "default" MPS generator
7964        # like the global torch.manual_seed()
7965        torch.mps.manual_seed(230)
7966        # generate random numbers with offset `0`
7967        mps_y = torch.randn(5, device='mps')
7968        # seed values were the same, so the random tensor contents should match
7969        self.assertEqual(mps_x, mps_y)
7970
7971        # save the default generator's state (offset = 1) to restore it later
7972        g_state = torch.mps.get_rng_state()
7973
7974        # generate random numbers with offset `1`
7975        mps_x = torch.randn(5, device='mps')
7976        # in this case, the random results must differ from the last generated random results
7977        self.assertNotEqual(mps_x, mps_y)
7978        # since we called randn twice after seeding, the offset should be 2
7979        self.assertEqual(torch.mps._get_default_mps_generator().get_offset(), 2)
7980
7981        # mps_x was produced by g_state; we use it as our reference mps_y.
7982        mps_y = mps_x
7983
7984        # restore the previously saved state to the "default" MPS generator, and the results should match again
7985        torch.mps.set_rng_state(g_state)
7986        mps_x = torch.randn(5, device='mps')
7987        self.assertEqual(mps_x, mps_y)
7988
7989    def test_device_synchronize(self):
7990        # just running some ops each followed by a synchronize to wait for
7991        # MPS stream to finish running each of them
7992        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
7993            .to(device='mps', dtype=torch.float)
7994
7995        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
7996        torch.mps.synchronize()
7997        x = net1(x)
7998        torch.mps.synchronize()
7999        x.backward(torch.randn_like(x))
8000        torch.mps.synchronize()
8001
8002    @serialTest()
8003    def test_mps_allocator_module(self):
8004        # first garbage collect and empty the cached blocks
8005        gc.collect()
8006        torch.mps.empty_cache()
8007        # measure memory allocations from MPSAllocator
8008        current_alloc_before = torch.mps.current_allocated_memory()
8009        # after garbage collection and emptying the cache the
8010        # current_allocated_memory must be zero
8011        self.assertEqual(current_alloc_before, 0)
8012        # measure total memory allocations from Metal driver
8013        driver_alloc_before = torch.mps.driver_allocated_memory()
8014        # allocate a new 32 MB tensor (8M float32 elements) to force allocation of a new Metal Heap
8015        x = torch.ones(1024 * 1024 * 8, device="mps")
8016        # get memory allocations after allocating tensor x
8017        current_alloc_after = torch.mps.current_allocated_memory()
8018        driver_alloc_after = torch.mps.driver_allocated_memory()
8019        # current and driver memory allocations must have
8020        # grown at this point
8021        self.assertGreater(current_alloc_after, current_alloc_before)
8022        self.assertGreater(driver_alloc_after, driver_alloc_before)
8023
8024    def test_mps_allocator_stats(self):
8025        max_memory = torch.mps.recommended_max_memory()
8026        print(f"Recommended Max Memory: {max_memory / 1024 ** 3} GB")
8027        self.assertGreater(max_memory, 0)
8028
8029    # to verify this test, run the Xcode Instruments "Metal System Trace" or "Logging" tool,
8030    # press record, then run this python test, and press stop. Next expand
8031    # the os_signposts->PyTorchMPS and check if events or intervals are logged
8032    # like this example:
8033    # "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)"
8034    def test_mps_profiler_module(self):
8035        with torch.mps.profiler.profile(mode="event", wait_until_completed=False) as p:
8036            # just running some ops to capture the OS Signposts traces for profiling
8037            net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
8038                .to(device='mps', dtype=torch.float)
8039            x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
8040            x = net1(x)
8041
8042        torch.mps.profiler.start(mode="interval", wait_until_completed=True)
8043        # just running some ops to capture the OS Signposts traces for profiling
8044        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
8045        x = net1(x)
8046        torch.mps.profiler.stop()
8047
8048    def test_mps_event_module(self):
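        # record MPS events around some GPU work and check that a positive elapsed time is reported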
8049        startEvent = torch.mps.Event(enable_timing=True)
8050        startEvent.record()
8051        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
8052            .to(device='mps', dtype=torch.float)
8053        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
8054        x = net1(x)
8055        endEvent = torch.mps.Event(enable_timing=True)
8056        endEvent.record()
8057        elapsedTime = startEvent.elapsed_time(endEvent)
8058        self.assertGreater(elapsedTime, 0.0)
8059
8060    def test_jit_save_load(self):
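        # round-trip a scripted module holding an MPS tensor through torch.jit.save/load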
8061        m = torch.nn.Module()
8062        m.x = torch.rand(3, 3, device='mps')
8063        buffer = io.BytesIO()
8064        torch.jit.save(torch.jit.script(m), buffer)
8065        buffer.seek(0)
8066        n = torch.jit.load(buffer)
8067        self.assertEqual(n.x, m.x)
8068
8069    # Test random_, random_.to and random_.from
8070    def test_random(self):
8071        def helper(shape, low, high, dtype=torch.int32):
8072
8073            mps_out = torch.randint(low, high, shape, dtype=dtype, device='mps')
8074
8075            # We can't reliably check the mean and std.
8076            # Just make sure we don't return constant values
8077            self.assertNotEqual(mps_out.float().mean().item(), 0.)
8078            self.assertNotEqual(mps_out.float().std().item(), 0.)
8079
8080        helper([100, 100], 0, 10)
8081        helper([100, 100], 23, 89)
8082        helper([100, 100], 23, 89, dtype=torch.float32)
8083        helper([100, 100], 23, 89, dtype=torch.int64)
8084        helper([100, 100], 0, 2, dtype=torch.bool)
8085
8086        # Test random_
8087        for dtype in [torch.bool, torch.int8, torch.uint8, torch.int32, torch.float16, torch.float32]:
8088            x = torch.empty(10, 10, dtype=dtype, device='mps')
8089            x.random_()
8090            self.assertNotEqual(x.max().item(), 0)
8091
8092    # Test exponential
8093    def test_exponential(self):
8094        def helper(shape, lamda, dtype=torch.float32):
8095
8096            mps_out = torch.zeros(shape, device='mps', dtype=dtype)
8097            mps_out.exponential_(lamda)
8098
8099            self.assertEqual(mps_out.to('cpu').float().mean().item(), 1 / lamda, atol=0.2, rtol=0.2)  # mean ~ 1/lamda
8100            self.assertEqual(mps_out.to('cpu').float().std().item() ** 2, 1 / (lamda**2), atol=0.5, rtol=0.5)  # var ~ 1/lamda**2
8101
8102        for dtype in [torch.float32, torch.float16]:
8103            helper([100, 100], 2, dtype)
8104            helper([100, 100], 1, dtype)
8105            helper([100, 100], 3, dtype)
8106            helper([100, 100], 0.5, dtype)
8107
8108    def test_exponential_1(self):
8109        rate = torch.randn(5, 5).abs().requires_grad_()
8110        rate_1d = torch.randn(1).abs().requires_grad_()
8111        self.assertEqual(Exponential(rate).sample().size(), (5, 5))
8112        self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
8113        self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
8114        self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
8115        self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
8116        self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))
8117
8118    # Test add and sub
8119    def test_add_sub(self):
8120        def helper(shape, alpha, op_name, inplace):
8121            if op_name == "add":
8122                op = torch.Tensor.add_ if inplace else torch.add
8123            elif op_name == "sub":
8124                op = torch.Tensor.sub_ if inplace else torch.sub
8125
8126            for dtype in [torch.float16, torch.float32]:
8127                cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
8128                mps_x = cpu_x.detach().clone().to('mps')
8129
8130                cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
8131                mps_y = cpu_y.detach().clone().to('mps')
8132
8133                cpu_out = op(cpu_x, cpu_y, alpha=alpha)
8134                mps_out = op(mps_x, mps_y, alpha=alpha)
8135                # fp16 isn't accurate when alpha is passed
8136                # TODO: remove or fix 'tol' when we fix problems with fp16
8137                tol = 2e-3 if dtype is torch.float16 else None
8138                self.assertEqual(mps_out, cpu_out, rtol=tol, atol=tol)
8139                if not (cpu_y.shape != () and inplace):  # in-place output cannot be broadcasted.
8140                    # create a scalar tensor
8141                    cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
8142                    mps_s = cpu_s.detach().clone().to('mps')
8143                    # primary tensor is scalar
8144                    self.assertEqual(op(cpu_s, cpu_y), op(mps_s, mps_y))
8145                # create a scalar tensor
8146                cpu_s = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
8147                mps_s = cpu_s.detach().clone().to('mps')
8148                # secondary tensor is scalar
8149                self.assertEqual(op(cpu_x, cpu_s), op(mps_x, mps_s), rtol=tol, atol=tol)
8150
8151
8152        for op_name, inplace in product(["add", "sub"], [True, False]):
8153            helper((), 0.0, op_name, inplace)
8154            helper((2, 8, 4, 5), 0.0, op_name, inplace)
8155            helper((2, 8, 4, 5), 0.1, op_name, inplace)
8156            helper((2, 8, 4, 5), 1.0, op_name, inplace)
8157            helper((2, 8, 3, 5), 0.1, op_name, inplace)
8158            helper((2, 8, 3, 5), 0.2, op_name, inplace)
8159
8160    # Test add with scalars
8161    def test_add_scalars(self):
8162        def helper(alpha):
8163            for dtype in [torch.float16, torch.float32]:
8164                cpu_x = torch.tensor(2.3, device='cpu', dtype=dtype, requires_grad=False)
8165                x = cpu_x.detach().clone().to('mps')
8166
8167                cpu_y = torch.tensor(3.4, device='cpu', dtype=dtype, requires_grad=False)
8168                y = cpu_y.detach().clone().to('mps')
8169
8170                cpu_out = torch.add(cpu_x, cpu_y, alpha=alpha)
8171                out = torch.add(x, y, alpha=alpha)
8172                # fp16 isn't accurate when alpha is passed
8173                tol = 1e-3 if dtype is torch.float16 else None
8174                self.assertEqual(out, cpu_out, rtol=tol, atol=tol)
8175
8176        helper(1.0)
8177        helper(0.0)
8178        helper(0.1)
8179        helper(0.2)
8180
8181        # Test int32 tensor + int64 scalar add
8182        # see https://github.com/pytorch/pytorch/issues/79835#issuecomment-1164984534
8183        x = torch.ones(4, dtype=torch.int32, device='mps')
8184        self.assertEqual(x + 1, torch.full((4,), 2, dtype=torch.int32, device='mps'))
8185        self.assertTrue(torch.equal(x + 1.5, torch.full((4,), 2.5, device='mps')))
8186
8187    def test_types_binary_op(self):
8188        # Float * Bool
8189        cpu_x = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([True, False, True, False, True], device="cpu")
8190        mps_x = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([True, False, True, False, True], device="mps")
8191        self.assertEqual(cpu_x, mps_x)
8192        # Float * Int64
8193        cpu_y = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([1, 0, 1, 0, 1], device="cpu")
8194        mps_y = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([1, 0, 1, 0, 1], device="mps")
8195        self.assertEqual(cpu_y, mps_y)
8196
8197    def test_unary_ops(self):
8198        def helper(shape, op):
8199            for dtypef in [torch.float32]:
8200                cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
8201                mps_x = cpu_x.detach().clone().to('mps')
8202                self.assertEqual(op(cpu_x), op(mps_x))
8203
8204            for dtypei in [torch.int32, torch.int16]:
8205                cpu_x = torch.randint(0, 1000, shape, device='cpu', dtype=dtypei, requires_grad=False)
8206                mps_x = cpu_x.to('mps')
8207                self.assertEqual(op(cpu_x), op(mps_x), rtol=1e-4, atol=1e-4)
8208            # test slice
8209            for dtypef in [torch.float32]:
8210                cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
8211                mps_x = cpu_x.detach().clone().to('mps')
8212                cpu_slice = cpu_x[:, ::2, :, :]
8213                mps_slice = mps_x[:, ::2, :, :]
8214                self.assertEqual(op(cpu_slice), op(mps_slice))
8215            # test view
8216            for dtypef in [torch.float32]:
8217                cpu_x = torch.randn(shape, device='cpu', dtype=dtypef, requires_grad=False)
8218                mps_x = cpu_x.detach().clone().to('mps')
8219                # create view of tensor by reducing the 3rd and 4th dimension
8220                combined_dim = shape[-1] * shape[-2]
8221                reshaped_dims = list(shape[:-2]) + [combined_dim]
8222                cpu_view = cpu_x.view(*reshaped_dims)
8223                mps_view = mps_x.view(*reshaped_dims)
8224                self.assertEqual(op(cpu_view), op(mps_view))
8225
8226        helper((2, 8, 4, 5), torch.exp)
8227        helper((2, 8, 3, 5), torch.exp2)
8228        helper((2, 8, 3, 5), torch.expm1)
8229        helper((2, 8, 3, 5), torch.log)
8230        helper((2, 8, 3, 5), torch.cos)
8231        helper((2, 8, 3, 5), torch.erfinv)
8232
8233
8234    def test_non_dense_in_storage_unary_ops(self):
8235        def helper(op):
8236            for dtypef in [torch.float32]:
8237                cpu_x = torch.randn(100, device='cpu', dtype=dtypef, requires_grad=False)
8238                mps_x = cpu_x.detach().clone().to('mps')
8239                self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]))
8240
8241            for dtypei in [torch.int32, torch.int16, torch.int8]:
8242                cpu_x = torch.randint(127, device='cpu', size=(100,), dtype=dtypei, requires_grad=False)
8243                mps_x = cpu_x.to('mps')
8244                self.assertEqual(op(cpu_x[::2]), op(mps_x[::2]), rtol=1e-4, atol=1e-4)
8245
8246        helper(torch.exp)
8247        helper(torch.exp2)
8248        helper(torch.expm1)
8249        helper(torch.log)
8250        helper(torch.cos)
8251
8252    def test_unary_ops_storage_offset_strided(self):
8253        def helper(shape, op, inplace, dtype=torch.float32):
8254            # test in-place with storage_offset
8255            cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
8256            mps_x = cpu_x.detach().clone().to('mps')
8257            y = op(mps_x[1])
8258            cpu_y = op(cpu_x[1])
8259            self.assertEqual(y, cpu_y)
8260
8261
8262            # See https://github.com/pytorch/pytorch/issues/100764
8263            if not inplace:
8264                cpu_x = torch.randn(shape, device='cpu', dtype=dtype)
8265                mps_x = cpu_x.detach().clone().to('mps')
8266                cpu_y = torch.empty(shape, device='cpu', dtype=dtype).t()
8267                mps_y = cpu_y.detach().clone().to('mps')
8268                op(cpu_x, out=cpu_y)
8269                op(mps_x, out=mps_y)
8270                self.assertEqual(mps_y, cpu_y)
8271
8272
8273        helper((5, 5), torch.exp, False)
8274        helper((5, 5), torch.cos, False)
8275        helper((5, 5), torch.neg, False)
8276        helper((5, 5), torch.tanh, False)
8277        helper((5, 5), torch.tanh_, True)
8278
8279    def test_atan2(self):
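        # compare torch.atan2 on MPS against the CPU reference for 1-D and 2-D inputs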
8280        def helper(shape):
8281            input_cpu = torch.randn(shape)
8282            input_mps = input_cpu.detach().clone().to("mps")
8283
8284            other_cpu = torch.randn(shape)
8285            other_mps = other_cpu.detach().clone().to("mps")
8286
8287            atan2_cpu = torch.atan2(input_cpu, other_cpu)
8288            atan2_mps = torch.atan2(input_mps, other_mps)
8289
8290            self.assertEqual(atan2_cpu, atan2_mps.to("cpu"))
8291
8292        helper(4)
8293        helper(10000)
8294        helper((10000, 40))
8295
8296    def test_multinomial(self):
8297        # Test with num_dist = 1
8298        def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True):
8299            cpu_prob_tensor = torch.tensor(probs, device='cpu', dtype=torch.float, requires_grad=False)
8300            prob_tensor = cpu_prob_tensor.detach().clone().to('mps')
8301
8302            mps_out = torch.multinomial(prob_tensor, num_samples, replacement=replacement)
8303            if (not replacement):
8304                self.assertEqual(mps_out.to('cpu').flatten().sort().values, torch.arange(num_samples))  # must be a permutation of the categories
8305            else:
8306                # Loosely compare the sample mean and variance with the theoretical values
8307                self.assertEqual(mps_out.to('cpu').float().mean().item(), compare_mean, atol=0.2, rtol=0.2)
8308                self.assertEqual(mps_out.to('cpu').float().std().item() ** 2, compare_var, atol=0.3, rtol=0.3)
8309
8310        # TODO: Add tests for data types
8311        helper(np.array([[0., 0., 0., 0.5, 0.5]]), (3 + 4) / 2, (12.5 - 3.5 ** 2), 100000)
8312        helper(np.array([[.2, .2, .2, .2, .2]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
8313        helper(np.array([[1, 1, 1, 1, 1]]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
8314        helper(np.array([1, 1, 1, 1, 1]), (0 + 1 + 2 + 3 + 4) / 5, (6 - 2 * 2), 10000)
8315        helper(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False)
8316
8317    def test_cumsum_dim_check(self):
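        # negative dims must alias their positive counterparts; out-of-range dims must raise IndexError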
8318        x = torch.rand((3, 3), device="mps")
8319        self.assertEqual(x.cumsum(1), x.cumsum(-1))
8320        self.assertEqual(x.cumsum(0), x.cumsum(-2))
8321        self.assertRaises(IndexError, lambda: x.cumsum(2))
8322        self.assertRaises(IndexError, lambda: x.cumsum(-3))
8323
8324    def test_cumprod_dim_check(self):
8325        x = torch.rand((3, 3), device="mps")
8326        self.assertEqual(x.cumprod(1), x.cumprod(-1))
8327        self.assertEqual(x.cumprod(0), x.cumprod(-2))
8328        self.assertRaises(IndexError, lambda: x.cumprod(2))
8329        self.assertRaises(IndexError, lambda: x.cumprod(-3))
8330
8331class TestLogical(TestCaseMPS):
8332    def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
8333        return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)
8334
8335    def test_logical_not(self):
8336        def helper(x):
8337            cpu_x = x
8338            x = cpu_x.detach().clone().to('mps')
8339
8340            result = torch.logical_not(x)
8341            result_cpu = torch.logical_not(cpu_x)
8342
8343            self.assertEqual(result, result_cpu)
8344
8345        helper(self._wrap_tensor([1, 1, 0, 0]))
8346        helper(self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True))
8347        helper(self._wrap_tensor([True, True, False, False]))
8348        helper(self._wrap_tensor(1))
8349        helper(self._wrap_tensor(0))
8350        helper(self._wrap_tensor(True))
8351        helper(self._wrap_tensor(False))
8352
8353    def test_logical_and(self):
8354        def helper(x, other):
8355            cpu_x = x
8356            x = cpu_x.detach().clone().to('mps')
8357
8358            cpu_other = other
8359            other = cpu_other.detach().clone().to('mps')
8360
8361            result = torch.logical_and(x, other)
8362            result_cpu = torch.logical_and(cpu_x, cpu_other)
8363            self.assertEqual(result, result_cpu)
8364
8365        helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
8366        helper(
8367            self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
8368            self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
8369        )
8370        helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
8371        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
8372        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
8373        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
8374        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
8375
8376    def test_logical_or(self):
8377        def helper(x, other):
8378            cpu_x = x
8379            x = cpu_x.detach().clone().to('mps')
8380
8381            cpu_other = other
8382            other = cpu_other.detach().clone().to('mps')
8383
8384            result = torch.logical_or(x, other)
8385            result_cpu = torch.logical_or(cpu_x, cpu_other)
8386
8387            self.assertEqual(result, result_cpu)
8388
8389        helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
8390        helper(
8391            self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
8392            self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
8393        )
8394        helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
8395        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
8396        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
8397        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
8398        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
8399
8400    def test_logical_xor(self):
8401        def helper(x, other):
8402            cpu_x = x
8403            x = cpu_x.detach().clone().to('mps')
8404
8405            cpu_other = other
8406            other = cpu_other.detach().clone().to('mps')
8407
8408            result = torch.logical_xor(x, other)
8409            result_cpu = torch.logical_xor(cpu_x, cpu_other)
8410
8411            self.assertEqual(result, result_cpu)
8412
8413        helper(self._wrap_tensor([1, 1, 0, 0]), self._wrap_tensor([1, 0, 0, 1]))
8414        helper(
8415            self._wrap_tensor([1, 1, 0, 0], dtype=torch.float, requires_grad=True),
8416            self._wrap_tensor([1, 0, 0, 1], dtype=torch.float)
8417        )
8418        helper(self._wrap_tensor([True, True, False, False]), self._wrap_tensor([True, False, False, True]))
8419        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(1))
8420        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(0))
8421        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
8422        helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
8423
8424    def test_min_max(self):
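        # full-tensor min/max on MPS must match the CPU results across float, int and bool dtypes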
8425        def helper(dtype):
8426            for _ in range(10):
8427                if dtype == torch.float32 or dtype == torch.float16:
8428                    x = torch.randn((30, 15), device='mps', dtype=dtype)
8429                else:
8430                    x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
8431                x_cpu = x.to("cpu")
8432
8433                y = x.max()
8434                y_cpu = x_cpu.max()
8435                self.assertEqual(y, y_cpu)
8436
8437                z = x.min()
8438                z_cpu = x_cpu.min()
8439                self.assertEqual(z, z_cpu)
8440
8441        [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]]
8442
8443    def test_min_max_nan_propagation(self):
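        # min/max and amin/amax must propagate NaN the same way the CPU implementation does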
8444        def helper(dtype):
8445            cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu", dtype=dtype)
8446            mps_x = cpu_x.detach().clone().to('mps')
8447
8448            cpu_max = torch.max(cpu_x)
8449            mps_max = torch.max(mps_x).to('cpu')
8450
8451            cpu_amax = torch.amax(cpu_x)
8452            mps_amax = torch.amax(mps_x).to('cpu')
8453
8454            cpu_min = torch.min(cpu_x)
8455            mps_min = torch.min(mps_x).to('cpu')
8456
8457            cpu_amin = torch.amin(cpu_x)
8458            mps_amin = torch.amin(mps_x).to('cpu')
8459
8460            self.assertEqual(cpu_max, mps_max)
8461            self.assertEqual(cpu_amax, mps_amax)
8462            self.assertEqual(cpu_min, mps_min)
8463            self.assertEqual(cpu_amin, mps_amin)
8464        [helper(dtype) for dtype in [torch.float32, torch.float16, torch.bfloat16]]
8465
8466    def test_isin(self):
8467        def helper(dtype):
8468            shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]),
8469                      ([5], [10]), ([0], [5]), ([5], [0])]
8470            for shape_tuple in shapes:
8471                for inverted in [True, False]:
8472                    if dtype.is_floating_point:
8473                        # Half is not supported for CPU isin. Compute reference in FP32
8474                        A = torch.randn(size=shape_tuple[0], device='cpu', dtype=torch.float32)
8475                        B = torch.randn(size=shape_tuple[1], device='cpu', dtype=torch.float32)
8476                    else:
8477                        A = torch.randint(0, 100, size=shape_tuple[0], device='cpu', dtype=dtype)
8478                        B = torch.randint(0, 100, size=shape_tuple[1], device='cpu', dtype=dtype)
8479
8480                    A_mps = A.clone().detach().to('mps')
8481                    B_mps = B.clone().detach().to('mps')
8482
8483                    cpu_ref = torch.isin(A, B, invert=inverted)
8484                    # torch.isin returns a bool tensor, so the fp32-computed reference
8485                    # can be compared against the MPS result directly
8486
8487                    mps_out = torch.isin(A_mps, B_mps, invert=inverted)
8488                    self.assertEqual(mps_out, cpu_ref)
8489
8490        dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8]
8491        if product_version < 14.0:
8492            # Int types expected to fail on MacOS < 14.0
8493            dtypes = [torch.float32, torch.float16, torch.bfloat16]
8494
8495        [helper(dtype) for dtype in dtypes]
8496
8497    def test_isin_asserts(self):
8498        A = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
8499        B = torch.randn(size=[1, 4], device='mps', dtype=torch.float16)
8500        with self.assertRaisesRegex(RuntimeError, 'Expected elements.dtype()*'):
8501            out = torch.isin(A, B)
8502
8503
8504        C = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
8505        D = torch.randn(size=[1, 4], device='cpu', dtype=torch.float32)
8506        with self.assertRaisesRegex(RuntimeError, 'Expected elements.is_mps()*'):
8507            out = torch.isin(C, D)
8508
8509class TestSmoothL1Loss(TestCaseMPS):
8510
8511    def _smooth_l1_loss_helper(self, reduction="mean", requires_grad=False):
8512        # CPU
8513        input_cpu = torch.randn(4, 7, requires_grad=requires_grad)
8514        target_cpu = torch.randn(4, 7)
8515
8516        # MPS
8517        input_mps = input_cpu.detach().clone().to('mps').requires_grad_()
8518        target_mps = target_cpu.detach().clone().to('mps')
8519
8520        smooth_l1_loss_cpu = F.smooth_l1_loss(input_cpu, target_cpu, beta=1.0, reduction=reduction)
8521        smooth_l1_loss_mps = F.smooth_l1_loss(input_mps, target_mps, beta=1.0, reduction=reduction)
8522
8523        self.assertEqual(smooth_l1_loss_cpu, smooth_l1_loss_mps)
8524
8525        if requires_grad:
8526            smooth_l1_loss_cpu.backward()
8527            smooth_l1_loss_mps.backward()
8528            self.assertEqual(input_cpu.grad, input_mps.grad.to("cpu"))
8529
8530        return smooth_l1_loss_cpu, smooth_l1_loss_mps
8531
8532    def test_smooth_l1_loss_reduction_none(self):
8533        self._smooth_l1_loss_helper(reduction="none")
8534
8535    def test_smooth_l1_loss_reduction_mean(self):
8536        self._smooth_l1_loss_helper(reduction="mean")
8537
8538    def test_smooth_l1_loss_reduction_sum(self):
8539        self._smooth_l1_loss_helper(reduction="sum")
8540
8541    def test_smooth_l1_loss_reduction_mean_backward(self):
8542        self._smooth_l1_loss_helper(reduction="mean", requires_grad=True)
8543
8544    def test_smooth_l1_loss_reduction_mean_sum_backward(self):
8545        self._smooth_l1_loss_helper(reduction="sum", requires_grad=True)
8546
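    # Hedged sketch (hypothetical helper, not collected by unittest): with
    # beta=1.0 the loss checked above is 0.5 * d**2 for |d| < 1 and |d| - 0.5
    # otherwise, where d = input - target.
    def _smooth_l1_definition_demo(self):
        d = torch.linspace(-2, 2, steps=9)
        expected = torch.where(d.abs() < 1.0, 0.5 * d * d, d.abs() - 0.5)
        actual = F.smooth_l1_loss(d, torch.zeros_like(d), beta=1.0, reduction='none')
        self.assertEqual(actual, expected)
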
8547class TestNLLLoss(TestCaseMPS):
8548    def test_nll_loss_mismatched_batch(self, device='mps'):
8549        x = torch.randn((10, 3), requires_grad=True, device=device)
8550        # t should have size (10,)
8551        t = torch.zeros((3,), dtype=torch.int64, device=device)
8552        with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
8553            F.nll_loss(x, t)
8554
8555    def test_nll_loss_out_of_bounds_ignore_index(self):
8556
8557        def test_nll_loss_out_of_bounds_ignore_index_helper(device):
8558            output = []
8559            x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
8560                             0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
8561            t1 = torch.tensor([0, 1, 255, 0, 1, 2], dtype=torch.int64, device=device)
8562            t2 = torch.tensor([0, 1, 1, 0, -100, 2], dtype=torch.int64, device=device)
8563            for reduction in ['mean', 'none']:
8564                # out of bound ignore_index
8565                output.append(F.nll_loss(x, t1, ignore_index=255, reduction=reduction))
8566                # default ignore_index
8567                output.append(F.nll_loss(x, t2, reduction=reduction))
8568            return output
8569
8570        output_cpu = test_nll_loss_out_of_bounds_ignore_index_helper(device='cpu')
8571        output_mps = test_nll_loss_out_of_bounds_ignore_index_helper(device='mps')
8572
8573        for cpu, mps in zip(output_cpu, output_mps):
8574            self.assertEqual(cpu, mps)
8575
8576    def test_nll_loss_invalid_target_dim(self):
8577
8578        def _test_nll_loss_invalid_target_dim(device):
8580            x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
8581                             0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
8582            t = torch.zeros((6, 2), dtype=torch.int64, device=device)
8583            with self.assertRaisesRegex(RuntimeError, "1D target tensor expected"):
8584                F.nll_loss(x, t)
8585
8586        _test_nll_loss_invalid_target_dim(device='cpu')
8587        _test_nll_loss_invalid_target_dim(device='mps')
8588
8589    def test_nll_loss_invalid_weights(self):
8590
8591        def _test_nll_loss_invalid_weights(device):
8592            x = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1], [
8593                             0.3, 0.5, 0.2], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]], device=device)
8594            t = torch.tensor([0, 1, 2, 1, 1, 2], dtype=torch.int64, device=device)
8595            invalid_weights = [
8596                torch.zeros(4, device=device),
8597                torch.zeros((1, 3), device=device),
8598            ]
8599            msg = "weight tensor should be defined either for all 3 classes or no classes"
8600            for weight in invalid_weights:
8601                with self.assertRaisesRegex(RuntimeError, msg):
8602                    F.nll_loss(x, t, weight=weight)
8603
8604        _test_nll_loss_invalid_weights(device='cpu')
8605        _test_nll_loss_invalid_weights(device='mps')
8606
    def _nll_loss_helper(self, input_size, reduction, expected):
        # `expected` mirrors the reference values from the test_nn.py original;
        # here the CPU output serves as the reference for the MPS result.
8609        # CPU
8610        input = torch.rand(input_size, requires_grad=True, device='cpu')
8611        num_channels = input_size[1]
8612        target_size = (input_size[0], ) + tuple(input_size[2:])
8613        target = torch.randint(num_channels, target_size, device='cpu')
8614        weights = torch.randn(num_channels)
8615
8616        # MPS
8617        input_mps = input.detach().clone().to('mps').requires_grad_()
8618        target_mps = target.detach().clone().to('mps')
8619        weights_mps = weights.to("mps")
8620
8621        output_cpu = F.nll_loss(input, target, weight=weights, reduction=reduction)
8622        output_mps = F.nll_loss(input_mps, target_mps, weight=weights_mps, reduction=reduction)
8623        self.assertEqual(output_cpu, output_mps.to('cpu'))
8624
8625        output_cpu.sum().backward()
8626        output_mps.sum().backward()
8627        self.assertEqual(input.grad, input_mps.grad.to('cpu'))
8628
8629    def _nll_loss_1d_helper(self, input_size, reduction):
8630
8631        # CPU
8632        input = torch.rand(input_size, requires_grad=True, device='cpu')
8633        num_channels = input_size[0]
8634        target = torch.randint(num_channels, [], device='cpu')
8635
8636        # MPS
8637        input_mps = input.detach().clone().to('mps').requires_grad_()
8638        target_mps = target.detach().clone().to('mps')
8639
8640        output_cpu = F.nll_loss(input, target, reduction=reduction)
8641        output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
8642        self.assertEqual(output_cpu, output_mps.to('cpu'))
8643
8644        output_cpu.sum().backward()
8645        output_mps.sum().backward()
8646        self.assertEqual(input.grad, input_mps.grad.to('cpu'))
8647
8648    def test_nll_loss_1d(self, device='cpu'):
8649        self._nll_loss_1d_helper([10], "none")
8650        self._nll_loss_1d_helper([10], "mean")
8651        self._nll_loss_1d_helper([10], "sum")
8652
8653    def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
8654        self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
8655        self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))
8656        self._nll_loss_helper([2, 3, 1, 7], "none", torch.empty([2, 1, 7], device=device))
8657        self._nll_loss_helper([2, 3, 5, 1], "none", torch.empty([2, 5, 1], device=device))
8658        self._nll_loss_helper([2, 3, 5, 7, 1], "none", torch.empty([2, 5, 7, 1], device=device))
8659
8660    def test_nll_loss_empty_tensor_reduction_mean(self, device='cpu'):
8661        nan = torch.tensor(float('nan'), device=device)
8662        self._nll_loss_helper([1, 3], "mean", nan)
8663        self._nll_loss_helper([1, 3, 5, 7], "mean", nan)
8664        self._nll_loss_helper([2, 3, 1, 7], "mean", nan)
8665        self._nll_loss_helper([2, 3, 5, 1], "mean", nan)
8666        self._nll_loss_helper([2, 3, 5, 7, 1], "mean", nan)
8667
8668    def test_nll_loss_empty_tensor_reduction_sum(self, device='cpu'):
8669        zero = torch.tensor(0, device=device)
8670        self._nll_loss_helper([1, 3], "sum", zero)
8671        self._nll_loss_helper([1, 3, 5, 7], "sum", zero)
8672        self._nll_loss_helper([2, 3, 1, 7], "sum", zero)
8673        self._nll_loss_helper([2, 3, 5, 1], "sum", zero)
8674        self._nll_loss_helper([2, 3, 5, 7, 1], "sum", zero)
8675
8676    def test_nll_loss_byte_target_matches_long(self, device='cpu'):
8677        N, C = 10, 4
8678        input = torch.randn(N, C, device=device, requires_grad=True)
8679        target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)
8680
8681        def compute_result_and_gradient(reduction, target_dtype):
8682            result, grad = {}, {}
8683            for dev in ['cpu', 'mps']:
8684                input_dev = input.to(dev)
8685                input_ = input_dev.detach()
8686                input_.requires_grad_()
8687
8688                target_dev = target.to(dev)
8689
8690                prob = F.log_softmax(input_, dim=-1)
8691                loss = nn.NLLLoss(reduction=reduction)
8692                result[dev] = loss(prob, target_dev.to(target_dtype))
8693                result[dev].sum().backward()
8694                grad[dev] = input_.grad
8695
8696            return result, grad
8697
8698        for reduction in ["none", "mean", "sum"]:
8699            result_long, grad_long = compute_result_and_gradient(reduction, torch.long)
8700            result_byte, grad_byte = compute_result_and_gradient(reduction, torch.uint8)
8701
            self.assertEqual(result_long['mps'].to('cpu'), result_long['cpu'])
            self.assertEqual(grad_long['mps'].to('cpu'), grad_long['cpu'])
            self.assertEqual(result_byte['mps'].to('cpu'), result_byte['cpu'])
            self.assertEqual(grad_byte['mps'].to('cpu'), grad_byte['cpu'])
8704
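    # Hedged sketch (hypothetical helper, not collected by unittest): with
    # reduction='none', F.nll_loss simply gathers the negative log-probability
    # of each sample's target class, i.e. loss[i] = -log_prob[i, target[i]].
    def _nll_loss_definition_demo(self, device='mps'):
        log_prob = F.log_softmax(torch.randn(5, 3, device=device), dim=-1)
        target = torch.tensor([0, 2, 1, 1, 0], device=device)
        loss = F.nll_loss(log_prob, target, reduction='none')
        manual = -log_prob[torch.arange(5, device=device), target]
        self.assertEqual(loss, manual)
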
8705class TestTopK(TestCase):
8706    def _test_topk(self, shape, largest):
8707        cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
8708        x = cpu_x.detach().clone().to('mps')
8709        if isinstance(shape, tuple):
8710            for curr_dim, dim_size in enumerate(shape):
8711                for k in range(1, dim_size + 1):
8712                    topk_values, topk_indices = torch.topk(x, k, dim=curr_dim, largest=largest)
8713                    topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=curr_dim, largest=largest)
8714                    self.assertEqual(topk_values, topk_values_cpu)
8715                    self.assertEqual(topk_indices, topk_indices_cpu)
8716        else:
8717            for k in range(1, shape):
8718                topk_values, topk_indices = torch.topk(x, k, dim=0, largest=largest)
8719                topk_values_cpu, topk_indices_cpu = torch.topk(cpu_x, k, dim=0, largest=largest)
8720                self.assertEqual(topk_values, topk_values_cpu)
8721                self.assertEqual(topk_indices, topk_indices_cpu)
8722
8723    def test_topk(self):
8724        largest_vals = [True, False]
8725        shapes = [
8726            # Zero Element Tensors
8727            0,
8728            (1, 0),
8729            (0, 1),
8730            (1, 0, 1),
8731            # Multiple Element Tensors
8732            1,
8733            2,
8734            (5, 1),
8735            (1, 5),
8736            (5, 9, 7, 4),
8737        ]
8738
8739        for shape in shapes:
8740            for largest_val in largest_vals:
8741                with self.subTest(shape=shape, largest_val=largest_val):
8742                    self._test_topk(shape, largest_val)
8743
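    # Hedged sketch (hypothetical helper, not collected by unittest): topk
    # values must agree with a sort-based reference, and gathering the input
    # at the returned indices must reproduce the values.
    def _topk_sort_reference_demo(self):
        x = torch.randn(10, device='mps')
        values, indices = torch.topk(x, 3, largest=True)
        ref = torch.sort(x, descending=True).values[:3]
        self.assertEqual(values, ref)
        self.assertEqual(values, x[indices])
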
8744class TestNNMPS(NNTestCase):
8745
8746    def _create_basic_net(self):
8747        class Layer(nn.Module):
8748            def __init__(self) -> None:
8749                super().__init__()
8750                self.layer_dummy_param = Parameter(torch.empty(3, 5))
8751                self.layer_dummy_buf = Buffer(torch.zeros(1, 3, 3, 7))
8752
8753        class Net(nn.Module):
8754            def __init__(self) -> None:
8755                super().__init__()
8756                self.l1 = Layer()
8757                self.dummy_param = Parameter(torch.empty(3, 5))
8758                self.dummy_buf = Buffer(torch.zeros(7, 3, 3, 1))
8759
8760        l = Layer()
8761        n = Net()
8762        s = nn.Sequential(n, n)
8763
8764        return l, n, s
8765
8766    def test_requires_grad_(self):
8767        m = self._create_basic_net()[-1]
8768        assert len(list(m.buffers())) > 0, 'invalid test'
        assert all(not b.requires_grad for b in m.buffers()), 'invalid test'
8770        assert len(list(m.parameters())) > 0, 'invalid test'
        assert all(p.requires_grad for p in m.parameters()), 'invalid test'
8772        for requires_grad in (False, True):
8773            self.assertIs(m.requires_grad_(requires_grad), m)
8774            for p in m.parameters():
8775                self.assertEqual(p.requires_grad, requires_grad)
8776            for b in m.buffers():
8777                self.assertFalse(b.requires_grad)
8778
8779    def test_module_backcompat(self):
8780        from torch.serialization import SourceChangeWarning
8781        path = download_file('https://download.pytorch.org/test_data/linear.pt')
8782        with warnings.catch_warnings():
8783            warnings.simplefilter('ignore', SourceChangeWarning)
8784            m = torch.load(path)
8785        input = torch.randn(2, 3, dtype=torch.float)
8786        self.assertEqual(m(input).size(), (2, 5))
8787
8788    def test_conv_backcompat(self):
8789        from torch.serialization import SourceChangeWarning
8790        # This file was generated by running on PyTorch 1.0.1 on Python 2:
8791        #
8792        #     import torch
8793        #     from torch import nn
8794        #     m = nn.Conv2d(1, 1, 1)
8795        #     torch.save(m, 'legacy_conv2d.pt')
8796        #
8797        # NB: This Pickle also contains some Unicode data!
8798        path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
8799        with warnings.catch_warnings():
8800            warnings.simplefilter('ignore', SourceChangeWarning)
8801            m = torch.load(path, encoding='utf-8')
8802        input = torch.randn((1, 1, 1, 1), dtype=torch.float)
8803        self.assertEqual(m(input).size(), (1, 1, 1, 1))
8804
8805    def test_conv_expand(self):
8806        device = 'mps'
8807        input_ = torch.rand(2, 3, 16, 16, device=device)
8808        kernel = torch.rand(1, 1, 3, 11, device=device)
8809        tmp_kernel = kernel.expand(-1, 3, -1, -1)
        # This test only verifies that convolving with an expanded kernel does not crash.
        output = F.conv2d(input_, tmp_kernel, groups=1, padding=0, stride=1)

8813    def test_permute(self):
8814        M_cpu = torch.randn(5, 5)
8815        M_mps = M_cpu.to('mps')
8816
8817        output_cpu = M_cpu.permute(1, 0)
8818        output_mps = M_mps.permute(1, 0)
8819
8820        self.assertEqual(output_cpu, output_mps)
8821        self.assertEqual(output_cpu.size(), output_mps.size())
8822
8823    # Printing of non_contiguous should not crash
8824    def test_print_non_contiguous(self):
8825        print(torch.ones(100, 100, device='mps').nonzero())
8826        print(torch.ones(100, 100, device='mps').nonzero().contiguous())
8827
8828    def test_zero_grad(self):
8829        i = torch.randn(2, 5, requires_grad=True)
8830        module = nn.Linear(5, 5)
8831        for p in module.parameters():
8832            p.requires_grad = False
8833        module.zero_grad()
8834
8835        module.weight.requires_grad = True
8836        module.zero_grad()
8837        self.assertIsNone(module.weight.grad)  # uninitialized grad
8838
8839        module(i).sum().backward()
8840        self.assertIsNotNone(module.weight.grad)
8841        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
8842        module.zero_grad()
8843        self.assertIsNone(module.weight.grad)
8844
8845        module.bias.requires_grad = True
8846        module.zero_grad()
8847        self.assertIsNone(module.weight.grad)
8848        self.assertIsNone(module.bias.grad)
8849        module(i).sum().backward()
8850        self.assertIsNotNone(module.weight.grad)
8851        self.assertIsNotNone(module.bias.grad)
8852        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
8853        self.assertGreater(module.bias.grad.data.abs().sum(), 0)
8854
8855        # Force set to zeros.
8856        module.zero_grad(set_to_none=False)
8857        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
8858        self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
8859
8860        module.zero_grad()
8861        self.assertIsNone(module.weight.grad)
8862        self.assertIsNone(module.bias.grad)
8863
8864
8865    def test_no_grad(self):
8866        for dtype in [torch.bfloat16, torch.float, torch.double]:
8867            module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype)
8868            input = torch.randn(1, 2, 10, 10).to(dtype)
8869            x = input
8870            y = input.clone()
8871
8872            output = module(x)
8873            self.assertTrue(output.requires_grad)
8874            output.backward(torch.ones(1, 5, 10, 10))
8875
8876            with torch.no_grad():
8877                output2 = module(y)
8878                self.assertFalse(output2.requires_grad)
8879                self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10)))
8880
8881    def test_invalid_conv1d(self):
8882        for dtype in [torch.bfloat16, torch.float, torch.double]:
8883            module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype)
8884            input = torch.randn(1, 3, 4).to(dtype)
8885            with self.assertRaisesRegex(RuntimeError,
8886                                        r'Calculated padded input size per channel: \(4\). ' +
8887                                        r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'):
8888                module(input)
8889
8890            # Negative stride check
8891            module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype)
8892            input = torch.randn(1, 3, 4).to(dtype)
8893            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
8894                module(input)
8895
8896    def test_conv2d_discontiguous_weight(self):
8897        # Test for https://github.com/pytorch/pytorch/issues/55781
8898        x = torch.ones(64, 16, 16, 16)
8899        weight = torch.arange(0, 1.0, 1 / 2.0 ** 10).reshape(32, 16, 1, 2)[:, :, :, ::2]
8900        self.assertFalse(weight.is_contiguous())
8901        y = torch.nn.functional.conv2d(x, weight, None)
8902        if torch.backends.mkldnn.is_available():
8903            # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used
8904            with torch.backends.mkldnn.flags(enabled=False):
8905                y_ = torch.nn.functional.conv2d(x, weight, None)
8906                self.assertEqual(y, y_)
8907        self.assertEqual(y.sum(), 4186112.)
8908
8909    def test_invalid_conv2d(self):
8910        for dtype in [torch.bfloat16, torch.float, torch.double]:
8911            module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype)
8912            input = torch.empty(1, 1, 4, 4).to(dtype)
8913            self.assertRaises(RuntimeError, lambda: module(input))
8914
8915            module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True)
8916            input = torch.randn(1, 3, 1, 1)
8917            with self.assertRaisesRegex(RuntimeError,
8918                                        r'Calculated padded input size per channel: \(1 x 1\). ' +
8919                                        r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'):
8920                module(input)
8921
8922            # Negative stride check
8923            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype)
8924            input = torch.randn(1, 3, 4, 4).to(dtype)
8925            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
8926                module(input)
8927
8928            # Zero stride check
8929            module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True).to(dtype)
8930            input = torch.randn(1, 3, 4, 4).to(dtype)
8931            with self.assertRaisesRegex(RuntimeError, 'non-positive stride is not supported'):
8932                module(input)
8933
8934            # Input and weights on different devices
8935            self.assertRaisesRegex(RuntimeError,
8936                                   'must be on the same device',
8937                                   lambda: torch.conv2d(torch.rand(1, 3, 32, 32), torch.rand(1, 3, 3, 3, device='mps')))
8938            self.assertRaisesRegex(RuntimeError,
8939                                   'Input type \\(MPSFloatType\\) and weight type \\(torch\\.FloatTensor\\) should be the same',
8940                                   lambda: torch.conv2d(torch.rand(1, 3, 32, 32, device='mps'), torch.rand(1, 3, 3, 3)))
8941
8942
8943    def test_conv2d_valid_padding(self, device='mps'):
8944        # Test F.conv2d padding='valid' is the same as no padding
8945        x = torch.rand(1, 1, 1, 10, device=device).to(torch.float)
8946        y = torch.rand(1, 1, 1, 4, device=device).to(torch.float)
8947
8948        expect = F.conv2d(x, y)
8949        actual = F.conv2d(x, y, padding='valid')
8950        self.assertEqual(expect.to('cpu'), actual.to('cpu'))
8951
8952    def test_conv2d_backward_collision(self):
8953        # Test for https://github.com/pytorch/pytorch/issues/112998
8954        x = torch.rand(1, 1, 10, 10, device="mps", requires_grad=True)
8955        m1 = nn.Conv2d(1, 1, 3, stride=2, padding=1).to("mps")
8956        m2 = nn.Conv2d(1, 1, 4, stride=2, padding=1).to("mps")
8957        y1, y2 = m1(x), m2(x)
8958        self.assertEqual(y1.shape, y2.shape)
8959        y1.sum().backward()
8960        # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
8961        y2.sum().backward()
8962
8963    @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
8964    def test_conv3d_backward_collision(self):
8965        # Conv3D is only available from MacOS 13.2 onwards
8966        x = torch.rand(1, 1, 10, 10, 20, device="mps", requires_grad=True)
8967        m1 = nn.Conv3d(1, 1, 3, stride=2, padding=1).to("mps")
8968        m2 = nn.Conv3d(1, 1, 4, stride=2, padding=1).to("mps")
8969        y1, y2 = m1(x), m2(x)
8970        self.assertEqual(y1.shape, y2.shape)
8971        y1.sum().backward()
8972        # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
8973        y2.sum().backward()
8974
8975    def test_gemm_permute_transpose(self):
8976        batch_size = 32
8977        n = 20
8978        hidden = 768
8979        num_attention_heads = 12
8980        attention_head_size = hidden // num_attention_heads
8981
8982        def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
8983            new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
8984            x = x.view(new_x_shape)
8985            return x.permute(0, 2, 1, 3)
8986
8987        def attention2(key, *, workaround=False, device):
8988            key = transpose_for_scores(key)
8989            res = key.transpose(-1, -2)
8990            return res
8991
8992        A = torch.randn(batch_size, n, hidden)
8993        A_mps = A.detach().clone().to("mps")
8994
8995        r1 = attention2(A, device="cpu")
8996        r2 = attention2(A_mps, device="mps")
8997
8998        r2_cpu = r2.to("cpu")
8999        self.assertEqual(r1, r2_cpu)
9000
9001    def test_group_norm_backward(self, device='mps'):
9002        # See https://github.com/pytorch/pytorch/issues/88331 for more detail
9003        shape = [1, 4, 16, 16]
9004        x = torch.full(shape, 7.0, device=device)
9005
9006        target = torch.ones((1, 3, 128, 128), device=device)
9007
9008        conv_in = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
9009        conv_out = nn.Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), device=device)
9010        norm = nn.GroupNorm(32, 128, eps=1e-6, affine=True, device=device)
9011
9012        with torch.enable_grad():
9013            x = x.detach().requires_grad_()
9014            out = 5.5 * x
9015            out = conv_in(out)
9016            out = out + norm(out)
9017            out = out + norm(out)
9018            out = out + norm(out)
9019            out = F.interpolate(out, scale_factor=8.0, mode="nearest")
9020            out = norm(out)
9021            out = conv_out(out)
9022
9023            loss = (out - target).norm(dim=-1).sum()
9024            grad = -torch.autograd.grad(loss, x)[0]
9025            self.assertFalse(grad.detach().isnan().any().item(), 'NaN gradients returned by autograd')
9026
9027
9028    # def test_conv2d_same_padding(self, device='mps'):
9029        # x = torch.rand(1, 1, 10, 11, device=device)
9030        # y = torch.rand(1, 1, 4, 5, device=device)
9031        # expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :]
9032        # actual = F.conv2d(x, y, padding='same')
9033        # self.assertEqual(expect.to('cpu'), actual.to('cpu'))
9034
9035        # # With dilation
9036        # y = torch.rand(1, 1, 3, 4, device=device)
9037        # expect = F.conv2d(x, y, padding=(2, 3), dilation=2)
9038        # actual = F.conv2d(x, y, padding='same', dilation=2)
9039        # self.assertEqual(expect, actual)
9040
9041        # # Dilation with asymmetric padding
9042        # y = torch.rand(1, 1, 4, 4, device=device)
9043        # expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:]
9044        # actual = F.conv2d(x, y, padding='same', dilation=3)
9045        # self.assertEqual(expect, actual)
9046
9047
9048class TestPad(TestCaseMPS):
9049    def test_constant_pad(self):
9050        m = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5)
9051        input_cpu = torch.randn(1, 16, 16, 16)
9052        input_mps = input_cpu.detach().clone().to("mps")
9053        r_cpu = m(input_cpu)
9054        r_mps = m(input_mps)
9055        self.assertEqual(r_cpu, r_mps.to("cpu"))
9056
9057        # Arbitrary input dimensions
9058        pad = (1, 1, 0, 0, 0, 0)
9059        value = 3.5
9060        input_cpu = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3))
9061        input_mps = input_cpu.detach().clone().to("mps")
9062        r_cpu = F.pad(input_cpu, pad=pad, value=value)
9063        r_mps = F.pad(input_mps, pad=pad, value=value)
9064        self.assertEqual(r_cpu, r_mps.to("cpu"))
9065
9066    def test_circular_pad(self):
9067        # https://github.com/pytorch/pytorch/issues/80856
9068        k_cpu = torch.ones(3, 3, 9, 9)
9069        k_mps = k_cpu.detach().clone().to("mps")
9070
9071        x_cpu = torch.rand(1, 3, 32, 32)
9072        x_mps = x_cpu.detach().clone().to("mps")
9073
9074        x_pad_cpu = F.pad(x_cpu, (2, 2, 2, 2), mode='circular')
9075        x_pad_mps = F.pad(x_mps, (2, 2, 2, 2), mode='circular')
9076
9077        y_cpu = F.conv2d(x_pad_cpu, k_cpu)
9078        y_mps = F.conv2d(x_pad_mps, k_mps)
9079
9080        self.assertEqual(y_cpu, y_mps.cpu())
9081
9082    def test_constant_pad_4d_warning(self):
9083        inputCPU = torch.rand((1, 2, 2, 2, 1, 1))
9084        inputMPS = inputCPU.detach().clone().to('mps')
9085        outputCPU = F.pad(inputCPU, [0, 0, 0, 0, 0, 0, 1, 0])
9086        outputMPS = F.pad(inputMPS, [0, 0, 0, 0, 0, 0, 1, 0])
9087        self.assertEqual(outputCPU, outputMPS)
9088
9089    def test_pad(self):
9090        def helper(shape, padding, op, value=0):
9091            inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
9092            inputCPU.retain_grad()
9093            inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
9094
9095            if (op in [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]):
9096                padCriteria = op(padding, value)
9097            else:
9098                padCriteria = op(padding)
9099            outputCPU = padCriteria(inputCPU)
9100            outputMPS = padCriteria(inputMPS)
9101            self.assertEqual(outputCPU, outputMPS)
9102
9103            # backward pass (chose 0.6 just to have the grad_output != 1)
9104            outputCPU.backward(gradient=torch.full_like(outputCPU, 0.6))
9105            outputMPS.backward(gradient=torch.full_like(outputMPS, 0.6))
9106            self.assertEqual(inputCPU.grad, inputMPS.grad)
9107
9108        # 1D Padding
9109        helper((2, 4, 3), 2, nn.ReflectionPad1d)
9110        # verify if a change in shape of input would cause problems with graph caching
9111        helper((2, 4, 4), (1, 3), nn.ReflectionPad1d)
9112        # Replication 1D
9113        helper((2, 1, 6), 3, nn.ReplicationPad1d)
9114        # Constant Pad 1D
9115        helper((2, 3, 4), 2, nn.ConstantPad1d)
9116        # Constant Pad 1D with single dimension input
        helper((16,), (1, 2), nn.ConstantPad1d)
9118
9119        # 2D Padding
9120        helper((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
9121        # verify if a change in shape of input would cause problems with graph caching
9122        helper((2, 4, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
9123        # this should make the padding (2, 2, 2, 2)
9124        helper((2, 1, 6, 8), 2, nn.ReplicationPad2d)
9125        # verify if a change in shape of padding would cause problems with graph caching
9126        helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d)
9127        # Constant Pad 2D
9128        helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d)
9129        # input size < pad size
9130        helper((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d)
9131        # pad dims < input dims
9132        helper((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d)
9133        # pad dims == input dims
9134        helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d)
9135        # input.numel() == 0 but output.numel() > 0
9136        helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d)
9137        # pad dims < input dims - 2
9138        helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d)
9139
9140        # 3D Padding
9141        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
9142        # verify if a change in shape of padding would cause problems with graph caching
9143        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d)
9144        # case where input_d == pad_front/back for ReplicationPad3d
9145        helper((3, 4, 5, 6, 7), (1, 2, 3, 4, 5, 6), nn.ReplicationPad3d)
9146        # Constant Pad 3D
9147        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
9148        # input size < pad size
9149        helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
9150        # check the workaround for the right padding bug in Monterey
9151        helper((1, 2, 2, 2, 2), (0, 1), nn.ConstantPad3d)
9152
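    # Hedged sketch (hypothetical helper, not collected by unittest): constant
    # padding is equivalent to concatenating value-filled slabs along the
    # padded dimension, which gives a simple reference for the 1D case above.
    def _constant_pad_concat_demo(self):
        x = torch.randn(2, 3, device='mps')
        value = 3.5
        padded = F.pad(x, (1, 2), value=value)
        left = torch.full((2, 1), value, device='mps')
        right = torch.full((2, 2), value, device='mps')
        ref = torch.cat([left, x, right], dim=1)
        self.assertEqual(padded, ref)
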
9153    def test_constant_pad_nd_preserves_memory_format(self):
9154        nchw_tensor = torch.rand((1, 2, 5, 3))
9155        nchw_padded = torch.constant_pad_nd(nchw_tensor, [1, 2], 0.5)
9156        self.assertTrue(nchw_padded.is_contiguous(memory_format=torch.contiguous_format))
9157
9158        nhwc_tensor = nchw_tensor.contiguous(memory_format=torch.channels_last)
9159        nhwc_padded = torch.constant_pad_nd(nhwc_tensor, [1, 2], 0.5)
9160        self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))
9161
9162
9163class TestLinalgMPS(TestCaseMPS):
9164    def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
9165        dtype = t.dtype
9166        numpy_dtype = dtype
9167        alpha = 1.2 if alpha is None else alpha
9168        beta = 0.8 if beta is None else beta
9169        res1 = f(t, m, v, alpha=alpha, beta=beta)
9170        res2 = torch.full_like(res1, math.nan)
9171        if transpose_out:
9172            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
9173        f(t, m, v, alpha=alpha, beta=beta, out=res2)
9174        res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
9175        if beta != 0:
9176            res3 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
9177        res3 = torch.from_numpy(res3).to(dtype)
9178        self.assertEqual(res1, res2)
9179        self.assertEqual(res1, res3)
9180
9181    def test_addmm(self, device="mps", dtype=torch.float32):
9182        M = torch.randn(10, 25, device=device).to(dtype)
9183        m1 = torch.randn(10, 50, device=device).to(dtype)
9184        m2 = torch.randn(50, 25, device=device).to(dtype)
9185        self._test_addmm_addmv(torch.addmm, M, m1, m2)
9186
9187        # Test beta=0, M=nan
9188        M = torch.full((10, 25), math.nan, device=device).to(dtype)
9189        m1 = torch.randn(10, 50, device=device).to(dtype)
9190        m2 = torch.randn(50, 25, device=device).to(dtype)
9191        self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0)
9192
        # Test transpose
        def maybe_transpose(cond, m):
            if not cond:
                return m
            return m.t().clone(memory_format=torch.contiguous_format).t()

        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)
9204
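    # Hedged sketch (hypothetical helper, not collected by unittest):
    # torch.addmm computes beta * t + alpha * (m1 @ m2), which is the identity
    # the helper above validates against a NumPy reference.
    def _addmm_definition_demo(self, device="mps"):
        t = torch.randn(4, 6, device=device)
        m1 = torch.randn(4, 5, device=device)
        m2 = torch.randn(5, 6, device=device)
        res = torch.addmm(t, m1, m2, alpha=1.2, beta=0.8)
        ref = 0.8 * t + 1.2 * (m1 @ m2)
        self.assertEqual(res, ref)
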
9205    def _test_addr(self, f, t, m, v, alpha=None, beta=None):
9206        dtype = t.dtype
9207        numpy_dtype = dtype
9208        alpha = 1.2 if alpha is None else alpha
9209        beta = 0.8 if beta is None else beta
9210        res1 = f(t, m, v, alpha=alpha, beta=beta)
9211        res2 = alpha * np.outer(m.to(numpy_dtype).cpu().numpy(), v.to(numpy_dtype).cpu().numpy())
9212        if beta != 0:
9213            res2 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
9214        res2 = torch.from_numpy(res2).to(dtype)
9215        self.assertEqual(res1, res2)
9216
9217    def test_addr(self, device="mps", dtype=torch.float32):
9218        M = torch.randn(10, 25, device=device).to(dtype)
9219        m1 = torch.randn(10, device=device).to(dtype)
9220        m2 = torch.randn(25, device=device).to(dtype)
9221        self._test_addr(torch.addr, M, m1, m2)
9222
9223        # Test beta=0, M=nan
9224        M = torch.full((10, 25), math.nan, device=device).to(dtype)
9225        m1 = torch.randn(10, device=device).to(dtype)
9226        m2 = torch.randn(25, device=device).to(dtype)
9227        self._test_addr(torch.addr, M, m1, m2, beta=0)
9228
9229    def test_matrix_rank(self, device="mps", dtype=torch.float32):
9230        matrix_rank = torch.linalg.matrix_rank
9231
9232        def run_test(shape0, shape1, batch):
9233            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
9234            rank_a = matrix_rank(a)
9235
9236            self.assertEqual(rank_a, matrix_rank(a.mH))
9237            aaH = torch.matmul(a, a.mH)
9238            rank_aaH = matrix_rank(aaH)
9239            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
9240            self.assertEqual(rank_aaH, rank_aaH_hermitian)
9241            aHa = torch.matmul(a.mH, a)
9242            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
9243
9244            # check against NumPy
9245            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
9246            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))
9247
9248            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
9249            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))
9250
9251            # hermitian flag for NumPy was added in 1.14.0
9252            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
9253                self.assertEqual(rank_aaH_hermitian,
9254                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
9255                self.assertEqual(matrix_rank(aaH, 0.01, True),
9256                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))
9257
9258            # check out= variant
9259            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
9260            ans = matrix_rank(a, out=out)
9261            self.assertEqual(ans, out)
9262            self.assertEqual(ans, rank_a)
9263
9264        shapes = (3, 13)
9265        batches = ((), (0, ), (4, ), (3, 5, ))
9266        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
9267            # escape only when NotImplementedError of downstream function is raised
9268            # TODO: remove this once the required function is implemented
9269            try:
9270                run_test(shape0, shape1, batch)
9271            except NotImplementedError as e:
9272                with self.assertRaisesRegex(
9273                        NotImplementedError,
9274                        "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."):
9275                    raise e
9276
9277    def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
9278        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
9279
9280        def run_test_main(A, hermitian):
9281            # Testing against definition for pseudo-inverses
9282            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
9283            np_A = A.cpu().numpy()
9284            np_A_pinv = A_pinv.cpu().numpy()
9285            if A.numel() > 0:
9286                self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=precision, rtol=precision)
9287                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=precision, rtol=precision)
9288                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
9289                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
9290            else:
9291                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))
9292
9293            # Check out= variant
9294            out = torch.empty_like(A_pinv)
9295            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
9296            self.assertEqual(ans, out)
9297            self.assertEqual(ans, A_pinv)
9298
9299        def run_test_numpy(A, hermitian):
9300            # Check against NumPy output
9301            # Test float rcond, and specific value for each matrix
9302            rconds = [float(torch.rand(1)), ]
9303            # Test different types of rcond tensor
9304            for rcond_type in MPS_DTYPES:
9305                rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type))
9306            # Test broadcasting of rcond
9307            if A.ndim > 2:
9308                rconds.append(torch.rand(A.shape[-3], device=device))
9309            for rcond in rconds:
9310                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
9311                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
9312                self.assertEqual(actual, torch_rtol, atol=precision, rtol=precision)
9313                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
9314                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
9315                self.assertEqual(actual, expected, atol=precision, rtol=precision)
9316
9317        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
9318                      (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
9319                      (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
9320                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
9321            A = torch.randn(*sizes, dtype=dtype, device=device)
9322            hermitian = False
9323            run_test_main(A, hermitian)
9324            run_test_numpy(A, hermitian)
9325
9326        # Check hermitian = True
9327        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
9328                      (0, 0), (3, 0, 0), ]:  # zero numel square matrices
9329            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
9330            hermitian = True
9331            # escape only when NotImplementedError of downstream function is raised
9332            # TODO: remove this once the required function is implemented
9333            try:
9334                run_test_main(A, hermitian)
9335            except NotImplementedError as e:
9336                with self.assertRaisesRegex(
9337                        NotImplementedError,
9338                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
9339                    raise e
9340            try:
9341                run_test_numpy(A, hermitian)
9342            except NotImplementedError as e:
9343                with self.assertRaisesRegex(
9344                        NotImplementedError,
9345                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
9346                    raise e
9347
9348    @parametrize("m", [1, 32, 64])
9349    @parametrize("n", [48, 64])
9350    @parametrize("q_group", [32, 64, 128, 256])
9351    @parametrize("num_groups", [1, 2])
9352    def test__int4_mm(self, m, n, q_group, num_groups):
9353        k = q_group * num_groups
9354        inner_k_tiles = 2
9355
9356        torch.manual_seed(1)
9357        a_f32 = torch.rand((m, k), device="mps")
9358        b_f32 = torch.rand((k, n), device="mps")
9359
9360        def convert_weight_to_int4pack(b):
9361            b_int32, b_scales_and_zeros = _group_quantize_tensor(
9362                b.to("cpu"), n_bit=4, q_group_size=q_group
9363            )
9364            b_int32 = b_int32.to("mps")
9365            b_scales_and_zeros = b_scales_and_zeros.to("mps")
9366            b_int4pack = torch._convert_weight_to_int4pack(
9367                b_int32, inner_k_tiles
9368            )
9369
9370            return b_int4pack, b_scales_and_zeros
9371
9372        def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
9373            return torch._weight_int4pack_mm(
9374                a, b_int4pack, q_group, b_scales_and_zeros
9375            )
9376
9377        b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b_f32)
9378
9379        for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []):
9380            a = a_f32.to(dtype=dtype)
9381            b = b_f32.to(dtype=dtype)
9382            b_scales_and_zeros = b_scales_and_zeros_f32.to(dtype=dtype)
9383            ref = torch.mm(a, b)
9384            res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
9385
9386            mean_err = ((res - ref).abs() / ref).mean()
9387            self.assertLess(mean_err, 0.05)
9388
9389    @parametrize("m", [1, 32, 64])
9390    @parametrize("k", [32, 64])
9391    @parametrize("n", [32, 64])
9392    def test__int8_mm(self, m, k, n):
9393        torch.manual_seed(1)
9394        a_f32 = torch.rand((m, k), device="mps")
9395        b_f32 = torch.rand((n, k), device="mps")
9396
9397        def convert_weight_to_int8pack(b):
9398            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
9399                b, -128, 127, torch.int8
9400            )
9401            return b_int8pack, b_scales
9402
9403        def weight_int8pack_mm(a, b_int8pack, b_scales):
9404            return torch._weight_int8pack_mm(a, b_int8pack, b_scales)
9405
9406        b_int8pack, b_scales_f32 = convert_weight_to_int8pack(b_f32)
9407        for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []):
9408            a = a_f32.to(dtype=dtype)
9409            b = b_f32.to(dtype=dtype)
9410            b_scales = b_scales_f32.to(dtype=dtype)
9411            res = weight_int8pack_mm(a, b_int8pack, b_scales)
9412            ref = torch.mm(a, b.transpose(0, 1))
9413
9414            mean_err = ((res - ref).abs() / ref).mean()
9415            self.assertLess(mean_err, 0.05)
9416
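    # Hedged sketch (hypothetical helper, not collected by unittest): the int8
    # tests above rely on per-channel quantization; this demo uses a common
    # symmetric scheme, w_q = round(w / s) with s = max(|w_row|) / 127, as an
    # assumption, not necessarily the exact scheme of the helper used above.
    def _int8_quant_roundtrip_demo(self):
        w = torch.randn(8, 16)
        scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-8)
        w_q = torch.round(w / scale).clamp(-128, 127).to(torch.int8)
        w_dq = w_q.to(torch.float32) * scale
        # Round-to-nearest keeps the error within half a quantization step.
        self.assertLessEqual((w - w_dq).abs().max().item(), 0.5 * scale.max().item())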
9417
9418class TestSDPA(TestCaseMPS):
9419    def _compare_tensors(self, y, ref):
9420        denom = torch.maximum(ref.abs(), torch.tensor([1e-6], device=ref.device, dtype=ref.dtype))
9421        err = ((y - ref).abs() / denom).mean().item()
9422        self.assertLess(err, 0.01)
9423
9424    def _test_sdpa_no_mask(
9425        self,
9426        is_causal: bool,
9427        dtype: torch.dtype,
9428        L: int = 1,
9429        S: int = 72,
9430        NH: int = 32,
9431        HS: int = 128,
9432        requires_grad: bool = False
9433    ):
9434
9435        torch.manual_seed(1729)
9436        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
9437            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps", requires_grad=requires_grad)
9438            k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
9439            v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
            q_cpu = q.detach().cpu().requires_grad_(requires_grad)
9441            k_cpu = k.cpu()
9442            v_cpu = v.cpu()
9443
9444            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal)
9445            y_ref = F.scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, dropout_p=0.0, is_causal=is_causal)
9446
9447            self._compare_tensors(y.cpu(), y_ref)
9448
9449            if requires_grad and torch.is_grad_enabled():
9450                y.sum().backward()
9451                y_ref.sum().backward()
9452
9453                self._compare_tensors(q.grad.cpu(), q_cpu.grad)
9454
9455    def test_sdpa_no_mask_no_causal_fp32(self):
9456        self._test_sdpa_no_mask(False, torch.float32)
9457
9458    def test_sdpa_no_mask_no_causal_fp16(self):
9459        self._test_sdpa_no_mask(False, torch.float16)
9460
9461    def test_sdpa_no_mask_causal_fp32(self):
9462        self._test_sdpa_no_mask(True, torch.float32)
9463
9464    def test_sdpa_no_mask_causal_fp16(self):
9465        self._test_sdpa_no_mask(True, torch.float16)
9466
9467    def test_sdpa_no_mask_causal_fp16_L7(self):
9468        self._test_sdpa_no_mask(True, torch.float16, 7)
9469
9470    def test_sdpa_no_mask_causal_fp16_L7_S17(self):
9471        self._test_sdpa_no_mask(True, torch.float16, 7, 17)
9472
9473    def test_sdpa_no_mask_causal_fp16_L7_S17_NH23_HS121(self):
9474        self._test_sdpa_no_mask(True, torch.float16, 7, 17, 23, 121)
9475
9476    def test_sdpa_no_mask_no_causal_fp32_grad(self):
9477        self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)
9478
9479        with torch.no_grad():
9480            self._test_sdpa_no_mask(False, torch.float32, requires_grad=True)
9481
9482    def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = 32, HS: int = 128):
9483        torch.manual_seed(1729)
9484        causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps'))
9485        with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
            i = min(42, S - 1)  # keep the query position inside the S x S mask
9487
9488            q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps")
9489            k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
9490            v = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps")
9491
9492            input_pos = torch.tensor([i], dtype=torch.int32, device='mps')
9493            mask = causal_mask[None, None, input_pos]
9494
9495            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
9496            y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False)
9497
9498            self._compare_tensors(y.cpu(), y_ref)
9499
9500    def test_sdpa_mask_fp32(self):
9501        self._test_sdpa_mask(torch.float32)
9502
9503    def test_sdpa_mask_fp16(self):
9504        self._test_sdpa_mask(torch.float16)
9505
9506    def test_sdpa_mask_fp16_L6(self):
9507        self._test_sdpa_mask(torch.float16, 6)
9508
    def test_sdpa_mask_fp16_L7_S17_NH23_HS121(self):
9510        self._test_sdpa_mask(torch.float16, 7, 17, 23, 121)
9511
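    # Hedged sketch (hypothetical helper, not collected by unittest): the MATH
    # backend used above matches the textbook formulation
    # softmax(q @ k^T / sqrt(head_dim)) @ v.
    def _sdpa_math_reference_demo(self):
        q = torch.randn(1, 2, 4, 8)
        k = torch.randn(1, 2, 6, 8)
        v = torch.randn(1, 2, 6, 8)
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
        ref = torch.softmax(scores, dim=-1) @ v
        y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        self._compare_tensors(y, ref)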
9512
9513class TestGatherScatter(TestCaseMPS):
9514    def test_slicing_with_step(self):
9515        # Slicing with step
9516        # https://github.com/pytorch/pytorch/issues/78886
9517        x_mps = torch.zeros(10, dtype=torch.float32, device="mps")
9518        x_mps[::2] = 1.0
9519
9520        x_cpu = torch.zeros(10, dtype=torch.float32, device="cpu")
9521        x_cpu[::2] = 1.0
9522
9523        self.assertEqual(x_cpu, x_mps)
9524
9525    def test_cast_gather_scatter(self):
        for _ in range(50):
9527            input = np.random.randint(0, 255, size=(5, 5, 4), dtype=np.uint8)
9528            with torch.no_grad():
9529                s = torch.tensor(input, dtype=torch.uint8, device="mps").unsqueeze(0)
9530                s_cpu = torch.tensor(input, dtype=torch.uint8, device="cpu").unsqueeze(0)
9531                s = s.long()
9532                s_cpu = s_cpu.long()
9533                self.assertEqual(s.cpu(), s_cpu)
9534
9535                s = s.float()
9536                s_cpu = s_cpu.float()
9537                self.assertEqual(s.cpu(), s_cpu)
9538
9539                s /= 255
9540                s_cpu /= 255
9541                self.assertEqual(s.cpu(), s_cpu)
9542
9543    def test_slicing_replace_column(self):
9544        # https://github.com/pytorch/pytorch/issues/78074
9545        def _helper(tensor_data):
9546            x_cpu = torch.tensor(tensor_data)
9547            x_mps = x_cpu.to('mps')
9548
9549            x_cpu[:, 0] = 7
9550            x_mps[:, 0] = 7
9551
9552            self.assertEqual(x_cpu, x_mps)
9553
9554        _helper([[1, 2, 3], [4, 5, 6]])
9555        _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
9556        _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
9557
9558    def test_inplace_scatter(self):
9559        # https://github.com/pytorch/pytorch/issues/79672
9560        a_mps = torch.ones((2, 2),).to(torch.device("mps"))
9561        b_mps = torch.ones((2, 2),).to(torch.device("mps"))
9562
9563        a_cpu = torch.ones((2, 2),).to(torch.device("cpu"))
9564        b_cpu = torch.ones((2, 2),).to(torch.device("cpu"))
9565
9566        a_mps[:, 0] += b_mps[:, 0]
9567        a_cpu[:, 0] += b_cpu[:, 0]
9568        self.assertEqual(a_cpu, a_mps)
9569
9570        a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0]
9571        a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
9572        self.assertEqual(a_cpu, a_mps)
9573
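    # Hedged sketch (hypothetical helper, not collected by unittest): strided
    # slice assignment such as x[::2] = 1.0 is expressible as a scatter along
    # the sliced dimension, which is what the cases above ultimately exercise.
    def _slice_assign_scatter_demo(self):
        x = torch.zeros(10, device="mps")
        idx = torch.arange(0, 10, 2, device="mps")
        x.scatter_(0, idx, torch.ones_like(idx, dtype=x.dtype))
        y = torch.zeros(10, device="mps")
        y[::2] = 1.0
        self.assertEqual(x, y)
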
# These tests were taken from test/test_view_ops.py.
# They are a subset of those tests, since currently only this subset works on MPS.
# This whole `class` will be removed when we add generic device testing; there
# are no additional tests here beyond what is already part of test_view_ops.py.
9578class TestViewOpsMPS(TestCaseMPS):
9579    exact_dtype = True
9580
9581    def test_permute_slicing(self):
9582        # test the fix for crash reported in
9583        # https://github.com/pytorch/pytorch/issues/94190
9584        cpu_x = (torch.randn([3, 2, 2]).float())
9585        mps_x = cpu_x.detach().clone().to('mps')
9586        cpu_out = cpu_x.permute((2, 0, 1)) * 2.0
9587        mps_out = mps_x.permute((2, 0, 1)) * 2.0
9588        # this print caused a crash prior to fix PR#94259
9589        print(torch.zeros_like(mps_out))
9590        # test the fix for fill_scalar_mps() mentioned in issue #94190
9591        self.assertEqual(torch.zeros_like(cpu_out), torch.zeros_like(mps_out))
9592        self.assertEqual(cpu_x[:, 1, :].fill_(1), mps_x[:, 1, :].fill_(1))
9593
9594    def is_view_of(self, base, other):
9595        if (not other._is_view() or
9596                other is base or
9597                other._base is not base or
9598                base.device != other.device):
9599            return False
9600        # Note: only validates storage on native device types
9601        # because some accelerators, like XLA, do not expose storage
9602        if base.device.type == 'mps':
9603            if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
9604                return False
9605
9606        return True
9607
9608    # Returns true if v1 and v2 are views of the same base
9609    def is_view_of_same_base(self, v1, v2):
9610        if (not v1._is_view() or v1 is v2):
9611            return False
9612        return self.is_view_of(v1._base, v2)
9613
9614    # Performs transpose if contiguous=True, else returns the input tensor as is
9615    def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
9616        if contiguous:
9617            return x
9618        else:
9619            return x.transpose(dim0, dim1)
9620
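    # Hedged sketch (hypothetical helper, not collected by unittest): a view
    # shares storage with its base, so is_view_of() above boils down to a
    # `_base` identity check plus, on MPS, a matching storage pointer.
    def _view_base_demo(self, device="mps"):
        t = torch.ones(4, 4, device=device)
        v = t.view(16)
        self.assertIs(v._base, t)
        self.assertTrue(self.is_view_of(t, v))
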
9621    def test_diagonal_view(self, device="mps"):
9622        t = torch.ones((5, 5), device=device)
9623        v = torch.diagonal(t)
9624        self.assertTrue(self.is_view_of(t, v))
9625
9626        v[0] = 0
9627        self.assertEqual(t[0, 0], v[0])
9628
9629        t = torch.ones((3, 3, 3), device="mps")
9630        v = torch.diagonal(t, offset=1, dim1=1, dim2=2)
9631        self.assertTrue(self.is_view_of(t, v))
9632
9633        v[0, 0] = 0
        self.assertEqual(t[0, 0, 1], v[0, 0])

    def test_select_view(self, device="mps") -> None:
        t = torch.ones((5, 5), device=device)
        v = t.select(0, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[0] = 0
        self.assertEqual(t[2, 0], v[0])

    def test_unbind_view(self, device="mps") -> None:
        t = torch.zeros((5, 5), device=device)
        tup = torch.unbind(t)

        for idx, v in enumerate(tup):
            self.assertTrue(self.is_view_of(t, v))

            v[0] = idx + 1
            self.assertEqual(t[idx, 0], v[0])

    def test_expand_view(self, device="mps") -> None:
        t = torch.ones((5, 1), device=device)
        v = t.expand(5, 5)
        self.assertTrue(self.is_view_of(t, v))

        v[2, 2] = 0
        self.assertEqual(t[2, 0], v[2, 2])

    def test_expand_as_view(self, device="mps"):
        t = torch.ones((5, 1), device=device)
        e = torch.empty((5, 5), device=device)
        v = t.expand_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[2, 2] = 0
        self.assertEqual(t[2, 0], v[2, 2])

    def test_narrow_view(self, device="mps"):
        t = torch.ones((5, 5), device=device)
        v = torch.narrow(t, 1, 2, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 2], v[0, 0])

    def test_permute_view(self, device="mps") -> None:
        t = torch.ones((5, 5), device=device)
        v = t.permute(1, 0)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_transpose_view(self, device="mps"):
        for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
            t = torch.ones((5, 5), device=device)
            v = fn(t, 0, 1)
            self.assertTrue(self.is_view_of(t, v))

            v[0, 1] = 0
            self.assertEqual(t[1, 0], v[0, 1])

    def test_transpose_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.swapdims_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.swapaxes_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.transpose_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_t_view(self, device="mps"):
        t = torch.ones((5, 5), device=device)
        v = t.t()
        self.assertTrue(self.is_view_of(t, v))

        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_inplace_view_add(self):
        # https://github.com/pytorch/pytorch/issues/96153
        t_mps = torch.ones((2, 6,), device='mps')[1].reshape(2, 3)
        t_cpu = torch.ones((2, 6,), device='cpu')[1].reshape(2, 3)
        t_mps = t_mps + 1
        t_cpu = t_cpu + 1
        self.assertEqual(t_mps, t_cpu)

    def test_t_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.t_()
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_T_view(self, device="mps"):
        for op in ("T", "H", "mT", "mH"):
            t = torch.ones((5, 5), device=device)
            v = getattr(t, op)
            self.assertTrue(self.is_view_of(t, v))

            v[0, 1] = 0
            self.assertEqual(t[1, 0], v[0, 1])

    def test_unfold_view(self, device="mps"):
        t = torch.ones(10, device=device)
        v = t.unfold(0, 3, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[1, 0] = 0
        self.assertEqual(t[2], v[1, 0])

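    # A minimal worked example (editor's sketch, not an upstream test) of the
    # index math behind the unfold assertion above: with size=3 and step=2,
    # window i starts at element i * step, so v[1, 0] aliases t[1 * 2] == t[2].
    def _example_unfold_index_math(self, device="mps"):
        t = torch.arange(10., device=device)
        v = t.unfold(0, 3, 2)  # shape (4, 3): windows start at 0, 2, 4, 6
        for i in range(v.shape[0]):
            self.assertEqual(v[i, 0], t[i * 2])
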
    def test_squeeze_view(self, device="mps"):
        t = torch.ones(5, 1, 5, device=device)
        v = torch.squeeze(t)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertIs(t, v._base)

    def test_squeeze_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.squeeze_()
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertIs(t, v._base)

    def test_unsqueeze_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = torch.unsqueeze(t, 1)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0, 1] = 0
        self.assertEqual(t[0, 1], v[0, 0, 1])

    def test_unsqueeze_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.unsqueeze_(1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 0, 1] = 0
        self.assertEqual(t[0, 1], v[0, 0, 1])

    def test_as_strided_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = torch.as_strided(t, (25,), (1,))
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

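    # Editor's sketch (not an upstream test): the as_strided assertion above
    # works because a (5, 5) contiguous tensor has strides (5, 1), so flat
    # index 6 maps to row 6 // 5 == 1, column 6 % 5 == 1, i.e. t[1, 1].
    def _example_as_strided_index_math(self, device="mps"):
        t = torch.arange(25., device=device).reshape(5, 5)
        v = torch.as_strided(t, (25,), (1,))
        self.assertEqual(v[6], t[6 // 5, 6 % 5])
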
    def test_as_strided_inplace_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.as_strided_((25,), (1,))
        self.assertTrue(self.is_view_of(t, v))
        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_view_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t.view(25)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_view_as_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        e = torch.empty((25,))
        v = t.view_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_contiguous_self(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        s = t.contiguous()
        self.assertIs(s, t)

    def test_contiguous_nonview(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        nv = t.t().contiguous()
        self.assertFalse(self.is_view_of(t, nv))

        nv[0, 0] = 0
        self.assertNotEqual(t[0, 0], nv[0, 0])

    def test_reshape_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = torch.reshape(t, (25,))
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_reshape_as_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        e = torch.empty((25,), device=device)
        v = t.reshape_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_reshape_nonview(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        nv = torch.reshape(t.t(), (25,))
        self.assertFalse(self.is_view_of(t, nv))

        nv[6] = 0
        self.assertNotEqual(t[1, 1], nv[6])

    def test_flatten_view(self, device="mps"):
        def test_writes_propagate(t, v):
            idx_t = (0,) * t.ndim
            idx_v = (0,) * v.ndim
            v[idx_v] = 0
            self.assertEqual(t[idx_t], v[idx_v])

        t = torch.ones(1, 2, 3, 4, device=device)
        v = t.flatten()
        self.assertTrue(self.is_view_of(t, v))
        test_writes_propagate(t, v)

        # zero-dimensional tensor
        t = torch.tensor(1, device=device)
        v = t.flatten()
        test_writes_propagate(t, v)
        self.assertTrue(self.is_view_of(t, v))

        t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
        v = t.flatten(0, 1)
        test_writes_propagate(t, v)
        self.assertTrue(self.is_view_of_same_base(t, v))

        # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
        t = torch.ones(720, device=device) \
            .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
        #               [--1--|---2---|-3-] [--1--|----2---|-3-]
        v1 = t.flatten(0, 1)
        v2 = v1.flatten(1, 3)
        v3 = v2.flatten(2, 2)
        test_writes_propagate(t, v1)
        self.assertTrue(self.is_view_of_same_base(t, v1))
        test_writes_propagate(t, v2)
        self.assertTrue(self.is_view_of_same_base(t, v2))
        test_writes_propagate(t, v3)
        self.assertTrue(self.is_view_of_same_base(t, v3))

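    # Editor's sketch (not an upstream test): flatten(start, end) can return a
    # view only when the flattened dims are laid out contiguously relative to
    # each other, i.e. stride[i] == stride[i + 1] * size[i + 1] for every dim
    # i in [start, end); otherwise flatten silently copies.
    def _example_flatten_view_condition(self, device="mps"):
        t = torch.ones(2, 3, 4, device=device)
        # contiguous: strides are (12, 4, 1) and 12 == 4 * 3, so a view
        self.assertEqual(t.flatten(0, 1).data_ptr(), t.data_ptr())
        # after transpose the condition fails, so flatten must copy
        tt = t.transpose(0, 1)  # strides (4, 12, 1); 4 != 12 * 2
        self.assertNotEqual(tt.flatten(0, 1).data_ptr(), tt.data_ptr())
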
    def test_flatten_nonview(self, device="mps"):
        def assert_is_nonview(t, nv):
            idx_t = (0,) * t.ndim
            idx_nv = (0,) * nv.ndim
            self.assertFalse(nv._is_view())
            nv[idx_nv] = 0
            self.assertNotEqual(t[idx_t], nv[idx_nv])
        t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
        nv = t.flatten(1, 3)
        assert_is_nonview(t, nv)

        t = torch.ones(2, 2, device=device).T
        nv = t.flatten()
        assert_is_nonview(t, nv)

        # flatten returns the original object if start_dim=end_dim
        t = torch.ones(2, 2, device=device)
        nv = t.flatten(1, 1)
        self.assertIs(t, nv)

    def test_basic_indexing_slice_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t[:2, :3]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 0], v[0, 0])

    def test_basic_indexing_ellipses_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t[..., :2]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 0], v[0, 0])

    def test_basic_indexing_newaxis_view(self, device="mps"):
        t = torch.ones(5, 5, device=device)
        v = t[None, :2, 3]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 3], v[0, 0])

    def test_chunk_view(self, device="mps"):
        t = torch.zeros(3, 3, device=device)
        l = torch.chunk(t, 3)

        for idx, v in enumerate(l):
            self.assertTrue(self.is_view_of(t, v))

            v[0, 0] = idx + 1
            self.assertEqual(t[idx, 0], v[0, 0])

    def test_split_view(self, device="mps"):
        t = torch.zeros(3, 3, device=device)
        l = torch.split(t, [1, 1, 1])

        for idx, v in enumerate(l):
            self.assertTrue(self.is_view_of(t, v))

            v[0, 0] = idx + 1
            self.assertEqual(t[idx, 0], v[0, 0])

    def test_movedim_view(self, device="mps"):
        def run_test(device, op):
            t = torch.zeros(3, 3, device=device)
            out = op(t)

            self.assertTrue(self.is_view_of(t, out))

            # Randomly change values in output and verify that the
            # original is changed as well.
            for _ in range(3):
                idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2)
                out[idx_1, idx_2] = random.random()
                self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2])

        for fn in [torch.movedim, torch.moveaxis]:
            op = partial(fn, source=(0, 1), destination=(1, 0))
            run_test(device, op)

            op = partial(fn, source=0, destination=1)
            run_test(device, op)

    # Testing that the generated view_copy kernel and its derivative are implemented correctly
    def test_view_copy(self, device="mps"):
        a = torch.randn(4, device=device, requires_grad=True)
        a_ref = a.clone().detach().requires_grad_()
        a_view = a_ref.view(2, 2)
        a_view_copy = torch.view_copy(a, (2, 2))

        # view_copy ops don't preserve view relationship
        self.assertTrue(self.is_view_of(a_ref, a_view))
        self.assertFalse(self.is_view_of(a, a_view_copy))

        a_view_copy.sum().backward()
        a_view.sum().backward()

        # forward and backward give the same shape + result
        self.assertEqual(a_view_copy, a_view)
        self.assertEqual(a.grad, a_ref.grad)

    def test_view_copy_out(self, device="mps"):
        a = torch.randn(2, 2, device=device)
        out = torch.empty(2, device=device)

        torch.diagonal_copy(a, out=out)
        expected = torch.diagonal_copy(a)

        self.assertEqual(expected, out)

        a = torch.randn(4, device=device)
        out1 = torch.empty(2, device=device)
        out2 = torch.empty(2, device=device)

        torch.split_copy(a, 2, out=(out1, out2))
        expected1, expected2 = torch.split_copy(a, 2)

        self.assertEqual(expected1, out1)
        self.assertEqual(expected2, out2)

    def test_detached_view_copy(self, device="mps"):
        # https://github.com/pytorch/pytorch/issues/86052
        x = torch.arange(2)
        # .detach() makes y not a view, but contig tensor
        # with non-zero offset
        y = x[1].detach()
        z = y.to(device)
        self.assertEqual(y, z.cpu())

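    # Editor's sketch (not an upstream test): the point of the regression test
    # above is that detach() keeps sharing the original storage, including its
    # non-zero storage_offset, which the MPS copy path has to honor.
    def _example_detach_keeps_offset(self):
        x = torch.arange(2)
        y = x[1].detach()
        # detach() documents that the result shares storage with the input
        self.assertEqual(y.storage_offset(), 1)
        self.assertEqual(y.data_ptr(), x[1].data_ptr())
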
    def test_empty_reshape(self, device="mps"):
        x = torch.randn(0, 6, device=device)
        self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)
        # should be viewable -- i.e. data_ptr is the same.
        self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr())

        # match NumPy semantics -- don't infer the size of dimension with a degree of freedom
        self.assertRaises(RuntimeError, lambda: x.reshape(0, -1))

    def test_expand(self, device="mps"):
        tensor = torch.rand(1, 8, 1, device=device)
        tensor2 = torch.rand(5, device=device)
        template = torch.rand(4, 8, 5, device=device)
        target = template.size()
        self.assertEqual(tensor.expand_as(template).size(), target)
        self.assertEqual(tensor.expand(4, 8, 5).size(), target)
        self.assertEqual(tensor.expand(target).size(), target)
        self.assertEqual(tensor2.expand_as(template).size(), target)
        self.assertEqual(tensor2.expand(4, 8, 5).size(), target)
        self.assertEqual(tensor2.expand(target).size(), target)

        # test double expand
        self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1))

        # test non-contiguous
        noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
        self.assertFalse(noncontig.is_contiguous())
        self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))

        # make sure it's compatible with unsqueeze
        expanded = tensor2.expand(1, 1, 5)
        unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
        self.assertEqual(expanded, unsqueezed)
        self.assertEqual(expanded.stride(), unsqueezed.stride())

        # test -1 as target size
        self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
        self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1))

        # test expanding empty to empty
        self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device))

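    # Editor's sketch (not an upstream test): expand() never allocates; it
    # broadcasts a size-1 dim by giving it stride 0, so every row of the
    # expanded tensor aliases the same five elements of the base tensor.
    def _example_expand_stride_zero(self, device="mps"):
        t = torch.rand(1, 5, device=device)
        v = t.expand(4, 5)
        self.assertEqual(v.stride(), (0, 1))
        self.assertEqual(v.data_ptr(), t.data_ptr())
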
    def test_view_empty(self, device="mps"):
        x = torch.randn(0, 6, device=device)
        self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape)

    def test_reshape(self, device="mps"):
        x = torch.randn(3, 3, device=device)
        self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
        self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
        self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
        self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))

        y = torch.randn(4, 4, 4, device=device)[:, 0, :]
        # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
        if device != "meta":
            self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
        self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
        self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())

        s = torch.randn((), device=device)
        self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
        self.assertEqual(s.reshape(-1).shape, (1,))
        self.assertRaises(RuntimeError, lambda: s.reshape(2))

        empty = torch.tensor([], device=device)
        self.assertEqual(empty, empty.reshape(-1))
        self.assertEqual(empty, empty.reshape([0]))
        # TODO: fix these once we have multi-dimensional empty tensors
        self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
        self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
        self.assertRaises(RuntimeError, lambda: empty.reshape(1))

        x = torch.randn(3, 3, device=device)
        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
        self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device)))

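    # Editor's sketch (not an upstream test): reshape() returns a view when
    # the requested shape is compatible with the existing strides and copies
    # otherwise; data_ptr() is a cheap way to tell the two cases apart.
    def _example_reshape_view_or_copy(self, device="mps"):
        x = torch.randn(3, 3, device=device)
        self.assertEqual(x.reshape(9).data_ptr(), x.data_ptr())         # view
        self.assertNotEqual(x.t().reshape(9).data_ptr(), x.data_ptr())  # copy
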
    def test_narrow(self, device="mps"):
        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]]))
        self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]]))
        self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]]))
        self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]]))
        self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]]))
        self.assertEqual(x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
        self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]]))
        self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]]))

    def test_narrow_tensor(self, device="mps"):
        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))
        with self.assertRaises(Exception):
            x.narrow(0, torch.tensor(0.), 1)
        with self.assertRaises(Exception):
            x.narrow(0, torch.tensor([0]), 1)
        with self.assertRaises(Exception):
            x.narrow(0, torch.tensor([0, 1]), 1)

    def test_t(self, device="mps"):
        # Test 0D tensors
        x = torch.randn(())
        self.assertEqual(x, x.t())
        x = x.to_sparse()
        self.assertEqual(x, x.t())

        # Test 1D tensors
        x = torch.arange(4)
        self.assertEqual(x, x.t())
        x = x.to_sparse()
        self.assertEqual(x, x.t())

        # Test 2D tensors
        x = torch.rand((2, 2))
        self.assertEqual(x.t(), x.transpose(0, 1))
        x = x.to_sparse()
        self.assertEqual(x.t(), x.transpose(0, 1))

        # Test 3D tensor
        x = torch.rand((2, 2, 2))
        with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'):
            x.t()
        x = x.to_sparse()
        with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'):
            x.t()

    def test_split(self, device="mps"):
        tensor = torch.rand(7, 4)
        split_size = 3
        dim = 0
        target_sizes = ([3, 4], [3, 4], [1, 4])
        splits = tensor.split(split_size, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
            start = start + target_size[dim]

        # Variable sections split
        tensor = torch.randn(20, 10)
        dim = 0
        split_sizes = [5, 5, 10]
        target_sizes = ([[5, 10], [5, 10], [10, 10]])
        splits = tensor.split(split_sizes, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
            start = start + target_size[dim]

        split_sizes = [2, 2, 6]
        target_sizes = ([20, 2], [20, 2], [20, 6])
        dim = 1
        splits = tensor.split(split_sizes, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
            start = start + target_size[dim]

    def test_chunk(self, device="mps"):
        tensor = torch.rand(4, 7)
        num_chunks = 3
        dim = 1
        target_sizes = ([4, 3], [4, 3], [4, 1])
        splits = tensor.chunk(num_chunks, dim)
        start = 0
        for target_size, split in zip(target_sizes, splits):
            self.assertEqual(split.size(), target_size)
            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split,
                             atol=0, rtol=0)
            start = start + target_size[dim]

        # Invalid chunk sizes
        error_regex = 'chunk expects.*greater than 0'
        with self.assertRaisesRegex(RuntimeError, error_regex):
            tensor.chunk(0)
        with self.assertRaisesRegex(RuntimeError, error_regex):
            tensor.chunk(-2)

    def test_unsqueeze(self, device="mps") -> None:
        x = torch.randn(2, 3, 4)
        y = x.unsqueeze(1)
        self.assertEqual(y, x.view(2, 1, 3, 4))
        y = x.clone().unsqueeze_(2)
        self.assertEqual(y, x.view(2, 3, 1, 4))

        x = x[:, 1]
        self.assertFalse(x.is_contiguous())
        y = x.unsqueeze(1)
        self.assertEqual(y, x.contiguous().view(2, 1, 4))
        y = x.clone().unsqueeze_(2)
        self.assertEqual(y, x.contiguous().view(2, 4, 1))

    # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
    def test_big_transpose(self, device="mps"):
        t = torch.rand(456, 789, device=device)
        t1 = t.t().contiguous()
        t2 = torch.from_numpy(t.cpu().numpy().transpose())
        self.assertEqual(t1, t2)

    def test_T(self, device="mps"):
        a = torch.randn(2, 3, 4, device=device)
        t1 = a.T
        t2 = a.permute(2, 1, 0)
        self.assertEqual(t2, t1)
        b = torch.randn(10, device=device)
        self.assertEqual(b, b.T)

    def test_transposes(self, device="mps", dtype=torch.float32):
        for op in ("T", "H", "mT", "mH", "adjoint"):
            shapes = ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),)
            for shape in shapes:
                a = make_tensor(shape, device=device, dtype=dtype)
                t1 = getattr(a, op)
                if op == "adjoint":
                    t1 = t1()
                t2 = a
                if a.ndim != 0:
                    t2 = t2.transpose(-2, -1)
                if op[-1] == "H" or op == "adjoint":
                    t2 = t2.conj()
                self.assertEqual(t2, t1)

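    # Editor's sketch (not an upstream test): the loop above relies on the
    # documented equivalences a.mT == a.transpose(-2, -1) and
    # a.mH == a.transpose(-2, -1).conj(); spelled out once for a single case.
    def _example_matrix_transpose_properties(self, device="mps"):
        a = make_tensor((2, 3), device=device, dtype=torch.float32)
        self.assertEqual(a.mT, a.transpose(-2, -1))
        self.assertEqual(a.mH, a.transpose(-2, -1).conj())
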
    def test_transposes_errors(self, device="mps", dtype=torch.float32):
        for op in ("H", "mT", "mH", "adjoint"):
            shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
            for shape in shapes:
                a = make_tensor(shape, device=device, dtype=dtype)
                with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
                    t1 = getattr(a, op)
                    if op == "adjoint":
                        t1 = t1()

    def test_python_types(self, device="mps"):
        a1 = torch.randn((1, 2), device=device, dtype=torch.float32)
        a2 = torch.randn((1, 2), device=device, dtype=torch.float32)
        self.assertEqual(a1.dtype, a2.dtype)

        b1 = torch.arange(10, 20, dtype=torch.int64, device=device)
        b2 = torch.arange(10, 20, dtype=int, device=device)
        self.assertEqual(b1.dtype, b2.dtype)

        c1 = torch.tensor([True, False], dtype=torch.bool, device=device)
        c2 = torch.tensor([True, False], dtype=bool, device=device)
        self.assertEqual(c1.dtype, c2.dtype)

    # TODO: is resize best put in test_view_ops?
    def test_resize_as_preserves_strides(self, device="mps"):
        x = torch.empty(2, 3).t()
        old_strides = x.stride()
        x.resize_as_(x)
        self.assertEqual(x.stride(), old_strides)

    def test_memory_format_resize_as(self, device="mps"):
        def test_helper(shape, memory_format, device="mps"):
            xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format)
            flat = torch.randn(xc.numel(), device=device)
            flat.resize_as_(xc, memory_format=torch.preserve_format)
            self.assertTrue(flat.is_contiguous(memory_format=memory_format))

        test_helper((10, 3, 32, 32), torch.channels_last, device="mps")
        test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device="mps")

    def test_memory_format_resize_(self, device="mps"):
        def test_helper(shape, numel, memory_format, device="mps"):
            flat = torch.randn(numel, device=device)
            flat.resize_(shape, memory_format=memory_format)
            self.assertTrue(flat.is_contiguous(memory_format=memory_format))

        test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device="mps")
        test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device="mps")

    # TODO: OpInfo this
    def _test_atleast(self, device, torch_fn):
        # 0-dim
        s = torch.tensor(0.5, dtype=torch.double, requires_grad=True)

        gradcheck(lambda x: torch_fn(x), s)
        gradgradcheck(lambda x: torch_fn(x), s)

        # 1-dim
        a = torch.rand(4, dtype=torch.double, requires_grad=True)

        gradcheck(lambda x: torch_fn(x), a)
        gradgradcheck(lambda x: torch_fn(x), a)

        # 2,3,4-dim
        b = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
        c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
        d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)

        input_tuple = (s, a, b, c, d)
        gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
        gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)

    def test_atleast_gradient(self, device="mps"):
        self._test_atleast(device, torch.atleast_1d)
        self._test_atleast(device, torch.atleast_2d)
        self._test_atleast(device, torch.atleast_3d)

    def test_view(self, device="mps"):
        tensor = torch.rand(15, device=device)
        template = torch.rand(3, 5, device=device)
        empty = torch.empty(0, device=device)
        target = template.size()
        self.assertEqual(tensor.view_as(template).size(), target)
        self.assertEqual(tensor.view(3, 5).size(), target)
        self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target)
        self.assertEqual(tensor.view(-1, 5).size(), target)
        self.assertEqual(tensor.view(3, -1).size(), target)
        tensor_view = tensor.view(5, 3)
        tensor_view.fill_(random.uniform(0, 1))
        self.assertEqual(empty.view_as(empty), empty)
        self.assertEqual(empty.view(0), empty)
        self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
        self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)

        # test size inference with empty tensors
        self.assertEqual(empty.view(-1).size(), torch.Size([0]))
        self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))

        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
            empty.view(-1, 0)

        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
            empty.view(3, 0, -1, 0)

        self.assertRaises(RuntimeError, lambda: tensor.view(15, 0))
        self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
        self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))

    def test_contiguous(self, device="mps"):
        x = torch.randn(1, 16, 5, 5, device=device)
        self.assertTrue(x.is_contiguous())
        stride = list(x.stride())
        stride[0] = 20
        # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
        x.set_(x.storage(), 0, x.size(), stride)
        self.assertTrue(x.is_contiguous())

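    # Editor's sketch (not an upstream test): is_contiguous() only constrains
    # the strides of dims with size > 1, so a size-1 dim can carry an
    # arbitrary stride without breaking contiguity.
    def _example_contiguity_ignores_size1_strides(self, device="mps"):
        x = torch.randn(1, 4, device=device)
        y = torch.as_strided(x, (1, 4), (999, 1))  # bogus stride on the size-1 dim
        self.assertTrue(y.is_contiguous())
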
    def test_resize_mps_dtypes(self, device="mps"):
        shape = (2, 2)
        for dt in MPS_DTYPES:
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            x.resize_(shape)
            self.assertEqual(shape, x.shape)

    def test_resize_as_mps_dtypes(self, device="mps"):
        for dt in MPS_DTYPES:
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
            x.resize_as_(y)
            self.assertEqual(y.shape, x.shape)

    def test_resize_overflow(self, device="mps"):
        x = torch.empty((), dtype=torch.float64)
        with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
            x.resize_([2, 4, 2**29, 2**29])
        with self.assertRaisesRegex(RuntimeError, 'overflow'):
            x.resize_([8, 8, 2**29, 2**29])

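    # Editor's note on the two overflow cases above: 2 * 4 * 2**29 * 2**29 is
    # 2**61 elements, which fits in int64 but overflows once multiplied by
    # 8 bytes per float64 (2**64 bytes), tripping the storage-size check;
    # 8 * 8 * 2**29 * 2**29 is 2**64 elements and already overflows the
    # element-count computation itself.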
    def test_view_all_dtypes_and_devices(self, device="mps"):
        for dt in (torch.float, torch.bool):
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            self.assertEqual(x.view(6).shape, [6])


class TestConvolutionMPS(TestCaseMPS):
    def test_conv1d_all_strides_paddings(self):
        # https://github.com/pytorch/pytorch/issues/82921
        def helper(stride, padding):
            y_cpu = torch.randn(1, 57, 40)
            conv_cpu = nn.Conv1d(57, 20, stride=stride, padding=padding, kernel_size=3, bias=False)
            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
            x_cpu = conv_cpu(y_cpu)

            y_gpu = y_cpu.to(device='mps')
            x_gpu = conv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu())
        for stride in range(1, 4):
            for padding in range(1, 4):
                helper(stride, padding)

    def test_conv1d_channels_last(self):
        # https://github.com/pytorch/pytorch/issues/81557
        model_cpu = torch.nn.Conv1d(1, 128, 3)
        a_cpu = torch.arange((128 * 176), dtype=torch.float32)
        a_cpu = a_cpu.view(128, 176, 1).permute(0, 2, 1)
        out_cpu = model_cpu(a_cpu)

        a_mps = a_cpu.detach().clone().to("mps")
        model_mps = model_cpu.to("mps")
        out_mps = model_mps(a_mps)

        self.assertEqual(out_cpu, out_mps.cpu(), rtol=2.6e-05, atol=2e-04)

    def test_conv_transpose_1d_all_strides(self):
        # https://github.com/pytorch/pytorch/issues/82711
        def helper(stride):
            y_cpu = torch.ones(1, 1, 2)
            deconv_cpu = nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=1, stride=stride, bias=False, padding=1)
            deconv_cpu.weight.data = torch.ones(1, 1, 2)
            deconv_gpu = copy.deepcopy(deconv_cpu).to(device='mps')
            x_cpu = deconv_cpu(y_cpu)

            y_gpu = y_cpu.to(device='mps')
            x_gpu = deconv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu())
        [helper(stride) for stride in [1, 2, 3]]

    def test_conv_transpose_1d_nn_functional(self):
        # https://github.com/pytorch/pytorch/issues/82563
        tin = torch.rand((1, 512, 1245), dtype=torch.float32)
        tparams = torch.rand((512, 256, 16), dtype=torch.float32)
        tbias = torch.rand((256), dtype=torch.float32)

        device = 'cpu'
        tcpu = torch.nn.functional.conv_transpose1d(tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4)

        device = 'mps'
        tgpu = torch.nn.functional.conv_transpose1d(tin.to(device), tparams.to(device), tbias.to(device), stride=8, padding=4)

        self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)

    def test_conv_backward_1d_channels_last(self):
        def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
            # https://github.com/pytorch/pytorch/issues/84511
            conv_cpu = torch.nn.Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
            conv_mps = torch.nn.Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
            conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
            conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True)

            data = torch.rand(shape, dtype=torch.float32)
            x_cpu = data.permute(0, 2, 1).contiguous().requires_grad_(True)
            x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True)
            res_cpu = conv_cpu(x_cpu)
            res_mps = conv_mps(x_mps)
            self.assertEqual(res_cpu, res_mps)
            # backward() returns None, so don't rebind res_cpu/res_mps to it
            res_cpu.sum().backward()
            res_mps.sum().backward()

            self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
            self.assertEqual(x_cpu.grad, x_mps.grad)

        helper(shape=(1, 176, 1))
        helper(shape=(2, 12, 1))
        helper(shape=(3, 176, 1))
        helper(shape=(4, 376, 1))
        helper(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1)
        helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)

    def test_conv1d_contiguous(self):
        model_cpu = torch.nn.Conv1d(1, 128, 3)
        a_cpu = torch.ones(128, 1, 176)
        out_cpu = model_cpu(a_cpu)

        a_mps = a_cpu.detach().clone().to("mps")
        model_mps = model_cpu.to("mps")
        out_mps = model_mps(a_mps)

        self.assertEqual(out_cpu.shape, out_mps.shape)
        self.assertEqual(out_cpu, out_mps.cpu())

    def test_conv2d_all_strides_paddings(self):
        # https://github.com/pytorch/pytorch/issues/83180
        def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
            x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
            x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()

            if permute_data:
                # NOTE: permute() is not in-place; as written these calls
                # discard their results, so permute_data has no effect.
                x_cpu.permute(0, 2, 3, 1)
                x_mps.permute(0, 2, 3, 1)

            for strideX in range(1, 4):
                for strideY in range(1, 4):
                    conv_cpu = torch.nn.Conv2d(
                        in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
                    conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()

                    conv_mps = torch.nn.Conv2d(
                        in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
                    conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
                    conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()

                    res_cpu = conv_cpu(x_cpu)
                    res_mps = conv_mps(x_mps)
                    self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)
                    # backward() returns None, so don't rebind res_cpu/res_mps to it
                    res_cpu.sum().backward()
                    res_mps.sum().backward()

                    self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
                    self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
                    self.assertEqual(x_cpu.grad, x_mps.grad)

        for mem_format_input in [torch.contiguous_format, torch.channels_last]:
            for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
                for permute_data in [True, False]:
                    helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
                    helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
                    helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)

    def test_conv_transpose_2d_strided(self):
        def helper(m_cpu, memory_format):
            m_mps = copy.deepcopy(m_cpu).requires_grad_()
            m_mps.weight.data = m_cpu.weight.data.detach().clone().to("mps").requires_grad_()
            m_mps.bias.data = m_cpu.bias.data.detach().clone().to("mps").requires_grad_()

            input_cpu = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
            input_mps = input_cpu.detach().clone().to("mps")

            output_cpu = m_cpu(input_cpu)
            output_mps = m_mps(input_mps)
            self.assertEqual(output_cpu, output_mps)

        for mem_format_input in [torch.contiguous_format, torch.channels_last]:
            # With square kernels and equal stride
            helper(nn.ConvTranspose2d(16, 33, 3, stride=2).requires_grad_(), mem_format_input)

            # non-square kernels and unequal stride and with padding
            helper(nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)).requires_grad_(), mem_format_input)

    def test_conv_transpose_2d_specified_output(self):
        input_cpu = torch.randn(1, 16, 12, 12)
        input_mps = input_cpu.detach().clone().to("mps")

        downsample_cpu = nn.Conv2d(16, 16, 3, stride=2, padding=1)
        downsample_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps")
        downsample_mps.weight.data = downsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
        downsample_mps.bias.data = downsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()

        upsample_cpu = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        upsample_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps")
        upsample_mps.weight.data = upsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
        upsample_mps.bias.data = upsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()

        h_cpu = downsample_cpu(input_cpu)
        h_mps = downsample_mps(input_mps)
        self.assertEqual(h_cpu, h_mps)

        size_cpu = h_cpu.size()
        size_mps = h_mps.size()
        self.assertEqual(size_cpu, size_mps)

        output_cpu = upsample_cpu(h_cpu, output_size=input_cpu.size())
        output_mps = upsample_mps(h_mps, output_size=input_mps.size())
        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    def test_conv2d_single_stride(self):
        y_cpu = torch.randn(2, 2, 3, 6)
        y_gpu = y_cpu.to(device='mps')
        for stride in range(1, 4):
            conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=stride)
            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
            x_cpu = conv_cpu(y_cpu)
            x_gpu = conv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)

    @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
    def test_conv3d_single_stride(self):
        # Conv3d is only available from MacOS 13.2 onwards
        # the 4-D input is interpreted as an unbatched (C, D, H, W) volume
        y_cpu = torch.randn(2, 2, 3, 6)
        y_gpu = y_cpu.to(device='mps')
        for stride in range(1, 4):
            conv_cpu = torch.nn.Conv3d(in_channels=2, out_channels=2, kernel_size=2, stride=stride)
            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
            x_cpu = conv_cpu(y_cpu)
            x_gpu = conv_gpu(y_gpu)
            self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)

    def test_grid_sample(self):
        def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
            def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):
                for grid_dim_contig_order in [(0, 1, 2, 3), (0, 3, 1, 2), (3, 0, 1, 2), (0, 2, 1, 3)]:
                    # grid_dim_contig_order specifies the dimension order that
                    # makes grid contiguous, i.e.
                    # grid.permute(grid_dim_contig_order) is contiguous.
                    # e.g., with grid_dim_contig_order=[0, 3, 1, 2], grid should be
                    #       initialized as a contiguous tensor of shape [N, 2, H, W]
                    #       and permuted to [N, H, W, 2] afterwards.
                    grid_shape = [N, H, W, 2]
                    grid_init_shape = [grid_shape[d] for d in grid_dim_contig_order]
                    grid_fwd_permute = [None, None, None, None]
                    for i, d in enumerate(grid_dim_contig_order):
                        grid_fwd_permute[d] = i

                    def get_grid(device='cpu', data=None):
                        if data is not None:
                            assert list(data.shape) == grid_shape
                            data = data.permute(grid_dim_contig_order).to(device)
                        else:
                            data = torch.randn(grid_init_shape, device=device)
                        grid = data.permute(grid_fwd_permute)
                        assert grid.permute(grid_dim_contig_order).is_contiguous()
                        return grid

                    input_cpu = torch.randn(C, N, IH, IW).transpose(0, 1).requires_grad_(input_requires_grad)
                    grid_cpu = get_grid().requires_grad_()
                    out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                            align_corners=align_corners)
                    self.assertEqual(out_cpu.size(), torch.Size([N, C, H, W]))

                    gradients = torch.randn_like(out_cpu)
                    out_cpu.backward(gradients)

                    # Compare against unvectorized CPU fallback

                    # NOTE [ grid_sample CPU fallback ]
                    # grid_sample uses AVX for 2d images, but that requires 32-bit indexing for
                    # 32-bit floats. So we also have a fallback that is used only for float tensors
                    # requiring 64-bit indexing. That requires too much memory to run on CI, so we
                    # also export the fallback and test it here to ensure feature parity with
                    # the vectorized version.
                    input_fallback = input_cpu.float().detach_().requires_grad_()
                    grid_fallback = grid_cpu.float().detach_().requires_grad_()
                    out_fallback = torch._grid_sampler_2d_cpu_fallback(
                        input_fallback, grid_fallback,
                        F.GRID_SAMPLE_INTERPOLATION_MODES[mode],
                        F.GRID_SAMPLE_PADDING_MODES[padding_mode],
                        align_corners)
                    self.assertEqual(out_fallback, out_cpu.float(), atol=1e-5, rtol=5e-5)

                    out_fallback.backward(gradients.float())
                    if input_requires_grad:
                        self.assertEqual(input_fallback.grad, input_cpu.grad.float(), atol=1e-4, rtol=5e-5)
                    self.assertEqual(grid_fallback.grad, grid_cpu.grad.float(), atol=1e-4, rtol=5e-5)

                    input_mps = input_cpu.detach().transpose(0, 1).to("mps").transpose(0, 1).requires_grad_(input_requires_grad)
                    grid_mps = get_grid('mps', grid_cpu.detach()).requires_grad_()
                    out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
                    self.assertEqual(out_cpu, out_mps)
                    out_mps.backward(gradients.to("mps"))
                    if input_requires_grad:
                        self.assertEqual(input_cpu.grad, input_mps.grad)
                    self.assertEqual(grid_cpu.grad, grid_mps.grad, atol=5e-5, rtol=0)

                    # check that zero-dimensional input strides don't error out
                    base_input = torch.randn(N, C, 1, IW)
                    input_cpu = base_input.expand_as(input_mps).requires_grad_(input_requires_grad)
                    out_cpu = F.grid_sample(input_cpu, grid_cpu, mode=mode, padding_mode=padding_mode,
                                            align_corners=align_corners)

                    input_mps = base_input.to("mps").expand_as(input_mps).requires_grad_(input_requires_grad)
                    out_mps = F.grid_sample(input_mps, grid_mps, mode=mode, padding_mode=padding_mode, align_corners=align_corners)
                    self.assertEqual(out_cpu, out_mps)

            # test same size output
            test_shape(N, C, H, W, H, W, mode, padding_mode, align_corners)

            # test larger output
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(IH + 1, 12)
            W = random.randint(IW + 1, 12)
            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

            # test smaller output
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(2, IH)
            W = random.randint(2, IW)
            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

            # test 1x1 input
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = 1
            IW = 1
            H = random.randint(2, 5)
            W = random.randint(2, 5)
            test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners)

            # testing empty grid
            N = random.randint(2, 8)
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            W = random.randint(3, IW + 2)
            test_shape(N, C, IH, IW, 0, W, mode, padding_mode, align_corners)

            # testing empty channel
            N = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(3, IH + 2)
            W = random.randint(3, IW + 2)
            test_shape(N, 0, IH, IW, H, W, mode, padding_mode, align_corners)

            # testing empty batch
            C = random.randint(2, 8)
            IH = random.randint(2, 8)
            IW = random.randint(2, 8)
            H = random.randint(3, IH + 2)
            W = random.randint(3, IW + 2)
            test_shape(0, C, IH, IW, H, W, mode, padding_mode, align_corners)

        # Only bilinear/nearest with zeros/reflection padding are exercised
        # here; the groundtruth for 'border' and 'bicubic' below is kept for
        # reference even though these loops do not reach it.
        for mode in ('bilinear', 'nearest'):
            for padding_mode in ('zeros', 'reflection'):
                for align_corners in (True, False):
                    # test known input
                    input = torch.arange(1., 11, device="mps").view(1, 1, 2, 5)
                    grid = torch.tensor(
                        [[[-0.9, -4.1], [0, 0.2000], [1, -1], [-0.333, 1e-6], [0.5, 1.0]],
                         [[-1.0, -0.5], [0, 0.3333], [1, -1], [-0.200, 1e-6], [1.5, 0.5]]], device="mps").view(1, 2, 5, 2)
                    if mode == 'bilinear':
                        if padding_mode == 'zeros':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[0.0000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 0.0000]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0.0000, 6.5000000000, 1.2500, 4.6675000191, 4.6250],
                                     [0.5000, 7.1665000916, 1.2500, 5.0000000000, 0.0000]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'border':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1.2000, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 8.7500]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[1.0000, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                     [1.0000, 7.1665000916, 5.0000, 5.0000000000, 10.0000]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[3.4500, 6.0000000000, 5.0000, 4.8340, 9.0000],
                                     [2.2500, 6.3332500450, 5.0000, 5.1000, 7.7500]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[3.0000004768, 6.5000000000, 5.0000, 4.6675000191, 9.2500],
                                     [1.0000000000, 7.1665000916, 5.0000, 5.0000000000, 9.2500]], device="mps").view(1, 1, 2, 5)
                        else:
                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                    elif mode == 'nearest':
                        if padding_mode == 'zeros':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[0., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0., 8., 5., 7., 0.],
                                     [1., 8., 5., 8., 0.]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'border':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 10.]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[1., 8., 5., 7., 9.],
                                     [1., 8., 5., 8., 9.]], device="mps").view(1, 1, 2, 5)
                        else:
                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                    elif mode == 'bicubic':
                        if padding_mode == 'zeros':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[-0.10424726, 7.1400003, 5.0000, 5.7842274, 9.0000],
                                     [2.4492188, 7.4814040, 5.0000, 6.0277520, 0.0000]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0.00000, 7.6287503, 1.0625, 5.5977230, 5.3270264],
                                     [0.40625, 8.0288770, 1.0625, 5.9375067, -0.3515625]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'border':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[1.1520010, 6.0599990, 5.0000, 4.870930, 9.0000000],
                                     [2.1328125, 6.4258375, 5.0000, 5.076003, 8.8671875]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[0.894531, 6.6050020, 4.625, 4.7138715, 9.800781],
                                     [0.906250, 7.2822485, 4.625, 5.0000052, 10.00000]], device="mps").view(1, 1, 2, 5)
                        elif padding_mode == 'reflection':
                            if align_corners:
                                groundtruth = torch.tensor(
                                    [[3.1822524, 6.239998, 5.0000, 4.8709273, 9.00000],
                                     [1.7812500, 6.703594, 5.0000, 5.0760007, 8.21875]], device="mps").view(1, 1, 2, 5)
                            else:
                                groundtruth = torch.tensor(
                                    [[2.7993753, 6.6050020, 4.25, 4.7138715, 10.269531],
                                     [0.8125000, 7.2822485, 4.25, 5.0000052, 9.332031]], device="mps").view(1, 1, 2, 5)
                        else:
                            raise AssertionError(f"missing groundtruth test for padding mode '{padding_mode}'")
                    else:
                        raise AssertionError(f"missing groundtruth test for interpolation mode '{mode}'")
                    output = F.grid_sample(input, grid, mode=mode, padding_mode=padding_mode,
                                           align_corners=align_corners)
                    self.assertEqual(output, groundtruth, atol=1e-5, rtol=0,
                                     msg=f"groundtruth comparison failed for mode={mode}, "
                                     f"padding_mode={padding_mode}")

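    # Editor's sketch (not an upstream test): grid_sample coordinates are
    # normalized to [-1, 1]; with align_corners=True, (-1, -1) addresses the
    # exact center of the top-left input pixel, so an identity grid built
    # from linspace reproduces the input.
    def _example_grid_sample_identity(self):
        inp = torch.arange(1., 5., device="mps").view(1, 1, 2, 2)
        ys, xs = torch.meshgrid(
            torch.linspace(-1, 1, 2, device="mps"),
            torch.linspace(-1, 1, 2, device="mps"),
            indexing="ij")
        grid = torch.stack((xs, ys), dim=-1).unsqueeze(0)  # (1, 2, 2, 2), (x, y) order
        out = F.grid_sample(inp, grid, mode='nearest', align_corners=True)
        self.assertEqual(out, inp)
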
10843class TestAdvancedIndexing(TestCaseMPS):
10844    supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
10845    supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
10846
10847    @unittest.skipIf(product_version < 14.0, "Skipped on macOS < 14")
10848    def test_nonzero_no_warning(self):
10849        device = "mps"
10850        t = torch.randn((2, 2), device=device)
10851        with warnings.catch_warnings(record=True) as w:
10852            warnings.simplefilter("always")
10853            torch.nonzero(t)
10854            t.nonzero()
10855            self.assertEqual(len(w), 0)
10856
10857    def test_nonzero(self):
10858        def helper(dtype):
10859            device = "mps"
10860            shapes = [
10861                torch.Size((12,)),
10862                torch.Size((12, 1)),
10863                torch.Size((1, 12)),
10864                torch.Size((6, 2)),
10865                torch.Size((3, 2, 2)),
10866                torch.Size((5, 5, 5)),
10867            ]
10868
10869            def gen_nontrivial_input(shape, dtype, device):
10870                if dtype != torch.bfloat16:
10871                    return torch.randint(2, shape, device=device, dtype=dtype)
10872                else:
10873                    # randint does not support bfloat16 on Windows, so generate as float and cast
10874                    return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)
10875
10876            for shape in shapes:
10877                tensor = gen_nontrivial_input(shape, dtype, device)
10878                dst1 = torch.nonzero(tensor, as_tuple=False)
10879                dst2 = tensor.nonzero(as_tuple=False)
10880                dst3 = torch.empty([], dtype=torch.long, device=device)
10881                dst3 = dst3.resize_(0)
10882                torch.nonzero(tensor, out=dst3)
10883                np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy()
10884                np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()
10885                self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0)
10886                self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0)
10887                self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0)
10888                tup1 = torch.nonzero(tensor, as_tuple=True)
10889                tup2 = tensor.nonzero(as_tuple=True)
10890                tup1 = torch.stack(tup1).t().cpu()
10891                tup2 = torch.stack(tup2).t().cpu()
10892                self.assertEqual(tup1, np_result, atol=0, rtol=0)
10893                self.assertEqual(tup2, np_result, atol=0, rtol=0)
10894        [helper(dtype) for dtype in self.supported_dtypes]
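        # Note: np.nonzero returns a tuple of per-dimension index arrays, while
        # torch.nonzero(as_tuple=False) returns an (N, ndim) coordinate matrix,
        # hence the stack().t() round-trip above. A minimal sketch (hypothetical
        # values, not part of the original test):
        #   a = np.array([[0, 1], [1, 0]])
        #   np.stack(a.nonzero()).T           # [[0, 1], [1, 0]]
        #   torch.nonzero(torch.tensor(a))    # tensor([[0, 1], [1, 0]])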
10895
10896    def test_nonzero_astuple_out(self):
10897        device = "mps"
10898        t = torch.randn((3, 3, 3), device=device)
10899        out = torch.empty([], dtype=torch.long, device=device)
10900        out = out.resize_(0)
10901
10902        with self.assertRaises(RuntimeError):
10903            torch.nonzero(t, as_tuple=True, out=out)
10904
10905        self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out))
10906
10907        # Verifies that JIT script cannot handle the as_tuple kwarg
10908        # See Issue https://github.com/pytorch/pytorch/issues/45499.
10909        def _foo(t):
10910            tuple_result = torch.nonzero(t, as_tuple=True)
10911            nontuple_result = torch.nonzero(t, as_tuple=False)
10912            out = torch.empty_like(nontuple_result)
10913            torch.nonzero(t, as_tuple=False, out=out)
10914            return tuple_result, nontuple_result, out
10915
10916        with self.assertRaises(RuntimeError):
10917            scripted_foo = torch.jit.script(_foo)
10918
10919        # Verifies that JIT tracing works fine
10920        traced_foo = torch.jit.trace(_foo, t)
10921        traced_tuple, traced_nontuple, traced_out = traced_foo(t)
10922        expected_tuple = torch.nonzero(t, as_tuple=True)
10923        expected_nontuple = torch.nonzero(t)
10924
10925        self.assertEqual(traced_tuple, expected_tuple)
10926        self.assertEqual(traced_nontuple, expected_nontuple)
10927        self.assertEqual(traced_out, expected_nontuple)
10928
10929    def test_nonzero_discontiguous(self):
10930        device = "mps"
10931        shape = (4, 4)
10932        tensor = torch.randint(2, shape, device=device)
10933        tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor)
10934        dst1 = tensor.nonzero(as_tuple=False)
10935        dst2 = tensor_nc.nonzero(as_tuple=False)
10936        self.assertEqual(dst1, dst2, atol=0, rtol=0)
10937        dst3 = torch.empty_like(dst1)
10938        data_ptr = dst3.data_ptr()
10939        # expect dst3 storage to be reused
10940        torch.nonzero(tensor, out=dst3)
10941        self.assertEqual(data_ptr, dst3.data_ptr())
10942        self.assertEqual(dst1, dst3, atol=0, rtol=0)
10943        # discontiguous out
10944        dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2]
10945        data_ptr = dst4.data_ptr()
10946        strides = dst4.stride()
10947        torch.nonzero(tensor, out=dst4)
10948        self.assertEqual(data_ptr, dst4.data_ptr())
10949        self.assertEqual(dst1, dst4, atol=0, rtol=0)
10950        self.assertEqual(strides, dst4.stride())
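        # The data_ptr/stride checks above rely on torch.nonzero(..., out=dst)
        # reusing dst's existing storage (and strides) when dst already has the
        # right dtype and enough room, instead of reallocating. A sketch of the
        # expectation (hypothetical size n):
        #   dst = torch.empty((n, tensor.dim()), dtype=torch.long, device=device)
        #   ptr = dst.data_ptr()
        #   torch.nonzero(tensor, out=dst)   # expected: dst.data_ptr() == ptr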
10951
10952    def test_nonzero_non_diff(self):
10953        device = "mps"
10954        x = torch.randn(10, requires_grad=True, device=device)
10955        nz = x.nonzero()
10956        self.assertFalse(nz.requires_grad)
10957
10958    def test_nonzero_multi_threading(self):
10959        # Test that MPS doesn't crash if nonzero called concurrently
10960        # See https://github.com/pytorch/pytorch/issues/100285
10961        x = torch.rand(3, 3, device="mps")
10962        t1 = threading.Thread(target=torch.nonzero, args=(x,))
10963        t2 = threading.Thread(target=torch.nonzero, args=(x,))
10964        t1.start()
10965        t2.start()
        # join so both concurrent nonzero calls actually finish before the test returns
        t1.join()
        t2.join()
10966
10967    def test_sliced_view_cast(self):
10968        # This used to crash on MacOS Sequoia
10969        # See https://github.com/pytorch/pytorch/issues/137800
10970        x = torch.rand(16, 16, device='mps', dtype=torch.float16)
10971        y = x[:, 0:2].view(torch.float32) + 1
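        # Viewing float16 data as float32 reinterprets raw bytes, so the last
        # dimension shrinks by the element-size ratio: the (16, 2) fp16 slice
        # above becomes a (16, 1) fp32 view. y is intentionally unused; the test
        # only checks that the sliced-view cast does not crash.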
10972
10973    def test_masked_select(self):
10974        x = torch.randn(3, 4)
10975        x_mps = x.to("mps")
10976        mask = x.ge(0.5)
10977        mask_mps = x_mps.ge(0.5)
10978
10979        res = torch.masked_select(x, mask)
10980        res_mps = torch.masked_select(x_mps, mask_mps)
10981
10982        self.assertEqual(res, res_mps)
10983
10984    # examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
10985    def test_indexing_get(self):
10986        def helper(dtype):
10987            x_cpu = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype)
10988            x_mps = x_cpu.detach().clone().to("mps")
10989
10990            y_cpu = x_cpu[[0, 1, 2], [0, 1, 0]]
10991            y_mps = x_mps[[0, 1, 2], [0, 1, 0]]
10992            self.assertEqual(y_cpu, y_mps, str(dtype))
10993        [helper(dtype) for dtype in self.supported_dtypes]
10994
10995    def test_indexing_select_corners(self):
10996        def helper(dtype):
10997            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
10998            x_mps = x_cpu.detach().clone().to("mps")
10999
11000            rows_cpu = torch.tensor([[0, 0], [3, 3]])
11001            rows_mps = rows_cpu.detach().clone().to("mps")
11002
11003            cols_cpu = torch.tensor([[0, 2], [0, 2]])
11004            cols_mps = cols_cpu.detach().clone().to("mps")
11005
11006            res_cpu = x_cpu[rows_cpu, cols_cpu]
11007            res_mps = x_mps[rows_mps, cols_mps]
11008
11009            self.assertEqual(res_cpu, res_mps, str(dtype))
11010        [helper(dtype) for dtype in self.supported_dtypes]
11011
11012    # FIXME: uint8 fails for this test case; needs further debugging
11013    def test_slicing_using_advanced_index_for_column(self):
11014        def helper(dtype):
11015            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
11016            x_mps = x_cpu.detach().clone().to("mps")
11017
11018            z_cpu = x_cpu[1:4, 1:3]
11019            z_mps = x_mps[1:4, 1:3]
11020            self.assertEqual(z_cpu, z_mps, str(dtype))
11021
11022            # using advanced index for column
11023            y_cpu = x_cpu[1:4, [1, 2]]
11024            y_mps = x_mps[1:4, [1, 2]]
11025            self.assertEqual(y_cpu, y_mps, str(dtype))
11026        # FIXME: use supported_dtypes once uint8 is fixed
11027        [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16]]
11028
11029    def test_boolean_array_indexing(self):
11030        def helper(dtype):
11031            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
11032            x_mps = x_cpu.detach().clone().to("mps")
11033
11034            res_cpu = x_cpu[x_cpu > 5]
11035            res_mps = x_mps[x_mps > 5]
11036
11037            self.assertEqual(res_cpu, res_mps, str(dtype))
11038        for dtype in self.supported_dtypes:
11039            # MPS supports binary ops with uint8 natively starting from macOS 13.0
11040            if product_version < 13.0 and dtype == torch.uint8:
11041                continue
11042            helper(dtype)
11043
11044    def test_advanced_indexing_3D_get(self):
11045        def helper(x_cpu):
11046            x_mps = x_cpu.detach().clone().to("mps")
11047            self.assertEqual(x_cpu[[1, 2], 3, :], x_mps[[1, 2], 3, :])
11048            self.assertEqual(x_cpu[[0, 2], :, :], x_mps[[0, 2], :, :])
11049            self.assertEqual(x_cpu[:, [1, 0], [1]], x_mps[:, [1, 0], [1]])
11050
11051        x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
11052                               [0.5, 0.6, 0.7, 0.8],
11053                               [0.9, 1.0, 1.1, 1.2],
11054                               [1.3, 1.4, 1.5, 1.6]],
11055
11056                              [[2.0, 2.1, 2.2, 2.3],
11057                               [2.4, 2.5, 2.6, 2.7],
11058                               [2.8, 2.9, 3.0, 3.1],
11059                               [3.2, 3.3, 3.4, 3.5]],
11060
11061                              [[4.0, 4.1, 4.2, 4.3],
11062                               [4.4, 4.5, 4.6, 4.7],
11063                               [4.8, 4.9, 5.0, 5.1],
11064                               [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
11065        helper(x_cpu)
11066        for idx in range(len(self.supported_np_dtypes)):
11067            # torch.randn / torch.rand don't work with all dtypes
11068            # Generate input data for all dtypes in NumPy, then move to torch
11069            input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
11070            inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])
11071
11072            helper(inputCPU)
11073
11074    def test_advanced_indexing_3D_put(self):
11075        def helper(x_cpu):
11076            dtype = x_cpu.dtype
11077            x_mps = x_cpu.detach().clone().to("mps")
11078
11079            out_tensor_cpu = torch.tensor([88, 99], dtype=dtype, device="cpu")
11080            out_tensor_cpu_view = out_tensor_cpu[1:]
11081
11082            out_tensor_mps = torch.tensor([88, 99], dtype=dtype, device="mps")
11083            out_tensor_mps_view = out_tensor_mps[1:]
11084
11085            x_cpu[[1, 2], 3, :] = out_tensor_cpu_view
11086            x_mps[[1, 2], 3, :] = out_tensor_mps_view
11087            self.assertEqual(x_cpu, x_mps)
11088
11089            x_cpu[[0, 2], :, :] = out_tensor_cpu_view
11090            x_mps[[0, 2], :, :] = out_tensor_mps_view
11091            self.assertEqual(x_cpu, x_mps)
11092
11093            x_cpu[:, [1, 0], [1]] = out_tensor_cpu_view
11094            x_mps[:, [1, 0], [1]] = out_tensor_mps_view
11095            self.assertEqual(x_cpu, x_mps)
11096
11097        x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
11098                               [0.5, 0.6, 0.7, 0.8],
11099                               [0.9, 1.0, 1.1, 1.2],
11100                               [1.3, 1.4, 1.5, 1.6]],
11101
11102                              [[2.0, 2.1, 2.2, 2.3],
11103                               [2.4, 2.5, 2.6, 2.7],
11104                               [2.8, 2.9, 3.0, 3.1],
11105                               [3.2, 3.3, 3.4, 3.5]],
11106
11107                              [[4.0, 4.1, 4.2, 4.3],
11108                               [4.4, 4.5, 4.6, 4.7],
11109                               [4.8, 4.9, 5.0, 5.1],
11110                               [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
11111        helper(x_cpu)
11112        for idx in range(len(self.supported_np_dtypes)):
11113            # torch.randn / torch.rand don't work with all dtypes
11114            # Generate input data for all dtypes in NumPy, then move to torch
11115            input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
11116            inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])
11117
11118            helper(inputCPU)
11119
11120    def test_index_put_with_view_indices(self):
11121        def helper(dtype):
11122            target_cpu = torch.zeros([5, 3], device="cpu", dtype=dtype)
11123            target_mps = torch.zeros([5, 3], device="mps", dtype=dtype)
11124
11125            indices_cpu = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="cpu")
11126            indices_mps = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="mps")
11127
11128            value_cpu = torch.ones(indices_cpu.shape[0], device="cpu", dtype=dtype)
11129            value_mps = torch.ones(indices_mps.shape[0], device="mps", dtype=dtype)
11130
11131            target_cpu.index_put_(tuple(indices_cpu.t()), value_cpu, accumulate=True)
11132            target_mps.index_put_(tuple(indices_mps.t()), value_mps, accumulate=True)
11133
11134            self.assertEqual(target_cpu, target_mps)
11135
11136        [helper(dtype) for dtype in [torch.int32, torch.float]]
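        # tuple(indices.t()) converts the (N, 2) coordinate matrix into one index
        # tensor per dimension, the layout index_put_ expects; with accumulate=True
        # duplicate coordinates add up. Sketch (hypothetical values):
        #   idx = torch.tensor([[0, 1], [0, 1]])          # element (0, 1) twice
        #   t = torch.zeros(2, 2)
        #   t.index_put_(tuple(idx.t()), torch.ones(2), accumulate=True)
        #   # t[0, 1] == 2.0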
11137
11138    # tests from 'test_indexing.py'
11139    def test_advancedindex_big(self, device="mps"):
11140        reference = torch.arange(0, 123344, dtype=torch.int, device=device)
11141
11142        self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
11143                         torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int))
11144
11145    def test_set_item_to_scalar_tensor(self, device="mps"):
11146        m = random.randint(1, 10)
11147        n = random.randint(1, 10)
11148        z = torch.randn([m, n], device=device)
11149        a = 1.0
11150        w = torch.tensor(a, requires_grad=True, device=device)
11151        z[:, 0] = w
11152        z.sum().backward()
11153        self.assertEqual(w.grad, m * a)
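        # w is broadcast into column 0 of all m rows, so z.sum() accumulates one
        # unit of gradient per row into w, i.e. w.grad == m (written as m * a
        # above since a == 1.0).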
11154
11155    def test_single_int(self, device="mps"):
11156        v = torch.randn(5, 7, 3, device=device)
11157        self.assertEqual(v[4].shape, (7, 3))
11158
11159    def test_multiple_int(self, device="mps"):
11160        v = torch.randn(5, 7, 3, device=device)
11161        self.assertEqual(v[4].shape, (7, 3))
11162        self.assertEqual(v[4, :, 1].shape, (7,))
11163
11164    def test_none(self, device="mps"):
11165        v = torch.randn(5, 7, 3, device=device)
11166        self.assertEqual(v[None].shape, (1, 5, 7, 3))
11167        self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
11168        self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
11169        self.assertEqual(v[..., None].shape, (5, 7, 3, 1))
11170
11171    def test_step(self, device="mps"):
11172        v = torch.arange(10, device=device)
11173        self.assertEqual(v[::1], v)
11174        self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
11175        self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
11176        self.assertEqual(v[::11].tolist(), [0])
11177        self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
11178
11179    def test_step_assignment(self, device="mps"):
11180        v = torch.zeros(4, 4, device=device)
11181        v[0, 1::2] = torch.tensor([3., 4.], device=device)
11182        self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
11183        self.assertEqual(v[1:].sum(), 0)
11184
11185    def test_bool_indices(self, device="mps"):
11186        v = torch.randn(5, 7, 3, device=device)
11187        boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
11188        self.assertEqual(v[boolIndices].shape, (3, 7, 3))
11189        self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))
11190
11191        v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
11192        boolIndices = torch.tensor([True, False, False], dtype=torch.bool, device=device)
11193        uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
11194        with warnings.catch_warnings(record=True) as w:
11195            self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
11196            self.assertEqual(v[boolIndices], v[uint8Indices])
11197            self.assertEqual(v[boolIndices], torch.tensor([True], dtype=torch.bool, device=device))
11198            self.assertEqual(len(w), 2)
11199
11200    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
11201    def test_bool_indices_accumulate(self, device="mps"):
11202        mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
11203        mask = mask > 0
11204        y = torch.ones(size=(10, 10), device=device)
11205        y.index_put_((mask, ), y[mask], accumulate=True)
11206        self.assertEqual(y, torch.ones(size=(10, 10), device=device))
11207
11208    def test_multiple_bool_indices(self, device="mps"):
11209        v = torch.randn(5, 7, 3, device=device)
11210        # note: these broadcast together and are transposed to the first dim
11211        mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
11212        mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
11213        self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
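        # Each boolean mask is converted to the index tensor of its True positions
        # (three each here); the two index tensors broadcast together, and because
        # they are separated by a slice, the advanced dims move to the front,
        # giving shape (3, 7) for a (5, 7, 3) input.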
11214
11215    def test_byte_mask(self, device="mps"):
11216        v = torch.randn(5, 7, 3, device=device)
11217        mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
11218        with warnings.catch_warnings(record=True) as w:
11219            self.assertEqual(v[mask].shape, (3, 7, 3))
11220            self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
11221            self.assertEqual(len(w), 2)
11222
11223        v = torch.tensor([1.], device=device)
11224        self.assertEqual(v[v == 0], torch.tensor([], device=device))
11225
11226    def test_byte_mask_accumulate(self, device="mps"):
11227        mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
11228        y = torch.ones(size=(10, 10), device=device)
11229        with warnings.catch_warnings(record=True) as w:
11230            warnings.simplefilter("always")
11231            y.index_put_((mask, ), y[mask], accumulate=True)
11232            self.assertEqual(y, torch.ones(size=(10, 10), device=device))
11233            self.assertEqual(len(w), 2)
11234
11235    def test_index_put_accumulate_expanded_values(self, device="mps"):
11236        t = torch.zeros((5, 2))
11237        t_dev = t.to(device)
11238        indices = [
11239            torch.tensor([0, 1, 2, 3]),
11240            torch.tensor([1, ]),
11241        ]
11242        indices_dev = [i.to(device) for i in indices]
11243        values0d = torch.tensor(1.0)
11244        values1d = torch.tensor([1.0, ])
11245
11246        out_mps = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
11247        out_cpu = t.index_put_(indices, values0d, accumulate=True)
11248        self.assertEqual(out_mps.cpu(), out_cpu)
11249
11250        out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
11251        out_cpu = t.index_put_(indices, values1d, accumulate=True)
11252        self.assertEqual(out_mps.cpu(), out_cpu)
11253
11254        t = torch.zeros(4, 3, 2)
11255        t_dev = t.to(device)
11256
11257        indices = [
11258            torch.tensor([0, ]),
11259            torch.arange(3)[:, None],
11260            torch.arange(2)[None, :],
11261        ]
11262        indices_dev = [i.to(device) for i in indices]
11263        values1d = torch.tensor([-1.0, -2.0])
11264        values2d = torch.tensor([[-1.0, -2.0], ])
11265
11266        out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
11267        out_cpu = t.index_put_(indices, values1d, accumulate=True)
11268        self.assertEqual(out_mps.cpu(), out_cpu)
11269
11270        out_mps = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
11271        out_cpu = t.index_put_(indices, values2d, accumulate=True)
11272        self.assertEqual(out_mps.cpu(), out_cpu)
11273
11274    def test_index_put_accumulate_non_contiguous(self, device="mps"):
11275        t = torch.zeros((5, 2, 2))
11276        t_dev = t.to(device)
11277        t1 = t_dev[:, 0, :]
11278        t2 = t[:, 0, :]
11279        self.assertFalse(t1.is_contiguous())
11280        self.assertFalse(t2.is_contiguous())
11281
11282        indices = [torch.tensor([0, 1]), ]
11283        indices_dev = [i.to(device) for i in indices]
11284        value = torch.randn(2, 2)
11285        out_mps = t1.index_put_(indices_dev, value.to(device), accumulate=True)
11286        out_cpu = t2.index_put_(indices, value, accumulate=True)
11287        self.assertFalse(t1.is_contiguous())
11288        self.assertFalse(t2.is_contiguous())
11289
11290        self.assertEqual(out_mps.cpu(), out_cpu)
11291
11292    def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
11293        # TODO: replace with a better solution.
11294        # Currently we use TorchScript to put None into the indices;
11295        # in C++ this yields indices as a list of 2 optional tensors: the first is
11296        # null and the second is a valid tensor.
11297        @torch.jit.script
11298        def func(x, i, v):
11299            idx = [None, i]
11300            x.index_put_(idx, v, accumulate=True)
11301            return x
11302
11303        n = 4
11304        t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
11305        t_dev = t.to(device)
11306        indices = torch.tensor([1, 0])
11307        indices_dev = indices.to(device)
11308        value0d = torch.tensor(10.0)
11309        value1d = torch.tensor([1.0, 2.0])
11310
11311        out_mps = func(t_dev, indices_dev, value0d.to("mps"))
11312        out_cpu = func(t, indices, value0d)
11313        self.assertEqual(out_mps.cpu(), out_cpu)
11314
11315        out_mps = func(t_dev, indices_dev, value1d.to("mps"))
11316        out_cpu = func(t, indices, value1d)
11317        self.assertEqual(out_mps.cpu(), out_cpu)
11318
11319    def test_index_put_accumulate_duplicate_indices(self, device="mps"):
11320        for i in range(1, 128):
11321            # generate indices by a random walk; this creates indices with
11322            # lots of duplicates interleaved with each other
11323            delta = torch.empty(i, dtype=torch.float32, device=device).uniform_(-1, 1)
11324
11325            indices = delta.cumsum(0).long().to("mps")
11326
11327            # abs for int64 is not supported on MPS; fall back to 'cpu' to compute it
11328            input = torch.randn(indices.cpu().abs().max().to("mps") + 1, device=device)
11329            values = torch.randn(indices.size(0), device=device)
11330            output = input.index_put((indices,), values, accumulate=True)
11331
11332            input_list = input.tolist()
11333            indices_list = indices.tolist()
11334            values_list = values.tolist()
11335            for idx, v in zip(indices_list, values_list):
11336                input_list[idx] += v
11337
11338            self.assertEqual(output, input_list)
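        # With accumulate=True every occurrence of a duplicate index must
        # contribute, matching the Python reference loop above. Sketch
        # (hypothetical values):
        #   t = torch.zeros(3)
        #   t.index_put_((torch.tensor([0, 0, 2]),), torch.tensor([1., 2., 3.]),
        #                accumulate=True)
        #   # t == tensor([3., 0., 3.])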
11339
11340    def test_index_put_deterministic(self, device="mps"):
11341        def helper(dtype, accumulate, deterministic, num_tests=128):
11342            acc_expected = torch.tensor([233, 187, 360], device=device, dtype=dtype)
11343            non_acc_expected = torch.tensor([38, 37, 39], device=device, dtype=dtype)
11344            t_idx = torch.tensor(
11345                [0, 0, 0, 0, 2, 2, 1, 0, 2, 1, 0, 1, 2, 1, 0, 2, 2, 2, 2, 2,
11346                 0, 0, 2, 1, 2, 1, 0, 0, 2, 0, 2, 1, 1, 2, 2, 0, 2, 1, 0, 2]
11347            )
11348            for _ in range(num_tests):
11349                try:
11350                    torch.use_deterministic_algorithms(deterministic)
11351                    t = torch.zeros(3, dtype=dtype, device=device)
11352                    t.index_put_((t_idx,), torch.arange(len(t_idx), device=device, dtype=dtype), accumulate=accumulate)
11353                    if accumulate:
11354                        self.assertEqual(t, acc_expected)
11355                    else:
11356                        self.assertEqual(t, non_acc_expected)
11357                finally:
11358                    torch.use_deterministic_algorithms(False)
11359
11360        for accumulate, deterministic in product((False, True), (False, True)):
11361            dtype = torch.float if accumulate else torch.long
11362            if not accumulate and not deterministic:
11363                with self.assertRaisesRegex(AssertionError, "Tensor-likes are not equal!"):
11364                    helper(dtype, accumulate, deterministic)
11365            else:
11366                helper(dtype, accumulate, deterministic)
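        # Without accumulate, duplicate indices make the surviving write
        # order-dependent; torch.use_deterministic_algorithms(True) pins that
        # order, so only the non-deterministic non-accumulate combination is
        # expected to diverge from the fixed expectation within num_tests runs.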
11367
11368    def test_multiple_byte_mask(self, device="mps"):
11369        v = torch.randn(5, 7, 3, device=device)
11370        # note: these broadcast together and are transposed to the first dim
11371        mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
11372        mask2 = torch.ByteTensor([1, 1, 1]).to(device)
11373        with warnings.catch_warnings(record=True) as w:
11374            warnings.simplefilter("always")
11375            self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
11376            self.assertEqual(len(w), 2)
11377
11378    def test_byte_mask2d(self, device="mps"):
11379        v = torch.randn(5, 7, 3, device=device)
11380        c = torch.randn(5, 7, device=device)
11381        num_ones = (c > 0).sum()
11382        r = v[c > 0]
11383        self.assertEqual(r.shape, (num_ones, 3))
11384
11385    def test_jit_indexing(self, device="mps"):
11386        def fn1(x):
11387            x[x < 50] = 1.0
11388            return x
11389
11390        def fn2(x):
11391            x[0:50] = 1.0
11392            return x
11393
11394        scripted_fn1 = torch.jit.script(fn1)
11395        scripted_fn2 = torch.jit.script(fn2)
11396        data = torch.arange(100, device=device, dtype=torch.float)
11397        out = scripted_fn1(data.detach().clone())
11398        ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
11399        self.assertEqual(out, ref)
11400        out = scripted_fn2(data.detach().clone())
11401        self.assertEqual(out, ref)
11402
11403    def test_int_indices(self, device="mps"):
11404        v = torch.randn(5, 7, 3, device=device)
11405        self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
11406        self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
11407        self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
11408
11409    def test_index_put_src_datatype(self):
11410        def helper(device, dtype):
11411            src = torch.ones(3, 2, 4, device=device, dtype=dtype)
11412            vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
11413            indices = (torch.tensor([0, 2, 1]),)
11414            res = src.index_put_(indices, vals, accumulate=True)
11415            self.assertEqual(res.shape, src.shape)
11416        [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.int32]]
11417
11418    @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
11419    def test_index_src_datatype(self):
11420        def helper(device, dtype):
11421            orig_dtype = dtype
11422            if dtype is torch.bool:
11423                dtype = torch.uint8
11424
11425            src = torch.ones(3, 2, 4, device=device, dtype=dtype)
11426            if orig_dtype is torch.bool:
11427                src = src == 1
11428            # test index
11429            res = src[[0, 2, 1], :, :]
11430            self.assertEqual(res.shape, src.shape)
11431            # test index_put, no accum
11432            src[[0, 2, 1], :, :] = res
11433            self.assertEqual(res.shape, src.shape)
11434        [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.float16, torch.long, torch.bool]]
11435
11436    def test_int_indices2d(self, device="mps"):
11437        # From the NumPy indexing example
11438        x = torch.arange(0, 12, device=device).view(4, 3)
11439        rows = torch.tensor([[0, 0], [3, 3]], device=device)
11440        columns = torch.tensor([[0, 2], [0, 2]], device=device)
11441        self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])
11442
11443    def test_int_indices_broadcast(self, device="mps"):
11444        # From the NumPy indexing example
11445        x = torch.arange(0, 12, device=device).view(4, 3)
11446        rows = torch.tensor([0, 3], device=device)
11447        columns = torch.tensor([0, 2], device=device)
11448        result = x[rows[:, None], columns]
11449        self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
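        # rows[:, None] has shape (2, 1) and broadcasts against columns (shape
        # (2,)) into a (2, 2) index grid, so result[i][j] == x[rows[i], columns[j]],
        # the outer-product form of NumPy-style integer indexing.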
11450
11451    def test_empty_index(self, device="mps"):
11452        x = torch.arange(0, 12, device=device).view(4, 3)
11453        idx = torch.tensor([], dtype=torch.long, device=device)
11454        self.assertEqual(x[idx].numel(), 0)
11455
11456        # an empty assignment should have no effect and must not throw an exception
11457        y = x.clone()
11458        y[idx] = -1
11459        self.assertEqual(x, y)
11460
11461        mask = torch.zeros(4, 3, device=device).bool()
11462        y[mask] = -1
11463        self.assertEqual(x, y)
11464
11465    def test_empty_ndim_index(self, device="mps"):
11466        x = torch.randn(5, device=device)
11467        self.assertEqual(torch.empty(0, 2, device=device), x[torch.empty(0, 2, dtype=torch.int64, device=device)])
11468
11469        x = torch.randn(2, 3, 4, 5, device=device)
11470        self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device),
11471                         x[:, torch.empty(0, 6, dtype=torch.int64, device=device)])
11472
11473        x = torch.empty(10, 0, device=device)
11474        self.assertEqual(x[[1, 2]].shape, (2, 0))
11475        self.assertEqual(x[[], []].shape, (0,))
11476        with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
11477            x[:, [0, 1]]
11478
11479    def test_empty_ndim_index_bool(self, device="mps"):
11480        x = torch.randn(5, device=device)
11481        self.assertRaises(IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)])
11482
11483    def test_empty_slice(self, device="mps"):
11484        x = torch.randn(2, 3, 4, 5, device=device)
11485        y = x[:, :, :, 1]
11486        z = y[:, 1:1, :]
11487        self.assertEqual((2, 0, 4), z.shape)
11488        # this isn't technically necessary, but matches NumPy stride calculations.
11489        self.assertEqual((60, 20, 5), z.stride())
11490        self.assertTrue(z.is_contiguous())
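        # Slicing keeps the parent's strides, so z still reports (60, 20, 5)
        # even though dim 1 now has size 0, matching NumPy; a tensor with a
        # zero-size dimension is also trivially contiguous.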
11491
11492    def test_index_getitem_copy_bools_slices(self, device="mps"):
11493        true = torch.tensor(1, dtype=torch.uint8, device=device)
11494        false = torch.tensor(0, dtype=torch.uint8, device=device)
11495
11496        tensors = [torch.randn(2, 3, device=device), torch.tensor(3., device=device)]
11497
11498        for a in tensors:
11499            self.assertNotEqual(a.data_ptr(), a[True].data_ptr())
11500            self.assertEqual(torch.empty(0, *a.shape), a[False])
11501            self.assertNotEqual(a.data_ptr(), a[true].data_ptr())
11502            self.assertEqual(torch.empty(0, *a.shape), a[false])
11503            self.assertEqual(a.data_ptr(), a[None].data_ptr())
11504            self.assertEqual(a.data_ptr(), a[...].data_ptr())
11505
11506    def test_index_setitem_bools_slices(self, device="mps"):
11507        true = torch.tensor(1, dtype=torch.uint8, device=device)
11508        false = torch.tensor(0, dtype=torch.uint8, device=device)
11509
11510        tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]
11511
11512        for a in tensors:
11513            # prefix with 1,1 dims to ensure we are compatible with NumPy, which cuts off leading 1s
11514            # (some of these ops already prefix a 1 to the size)
11515            neg_ones = torch.ones_like(a) * -1
11516            neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
11517            a[True] = neg_ones_expanded
11518            self.assertEqual(a, neg_ones)
11519            a[False] = 5
11520            self.assertEqual(a, neg_ones)
11521            a[true] = neg_ones_expanded * 2
11522            self.assertEqual(a, neg_ones * 2)
11523            a[false] = 5
11524            self.assertEqual(a, neg_ones * 2)
11525            a[None] = neg_ones_expanded * 3
11526            self.assertEqual(a, neg_ones * 3)
11527            a[...] = neg_ones_expanded * 4
11528            self.assertEqual(a, neg_ones * 4)
11529            if a.dim() == 0:
11530                with self.assertRaises(IndexError):
11531                    a[:] = neg_ones_expanded * 5
11532
11533    def test_index_scalar_with_bool_mask(self, device="mps"):
11534        a = torch.tensor(1, device=device)
11535        uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
11536        boolMask = torch.tensor(True, dtype=torch.bool, device=device)
11537        self.assertEqual(a[uintMask], a[boolMask])
11538        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
11539
11540        a = torch.tensor(True, dtype=torch.bool, device=device)
11541        self.assertEqual(a[uintMask], a[boolMask])
11542        self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
11543
11544    def test_setitem_expansion_error(self, device="mps"):
11545        true = torch.tensor(True, device=device)
11546        a = torch.randn(2, 3, device=device)
11547        # check that a prefix with non-1 sizes doesn't work
11548        a_expanded = a.expand(torch.Size([5, 1]) + a.size())
11549        # NumPy: ValueError
11550        with self.assertRaises(RuntimeError):
11551            a[True] = a_expanded
11552        with self.assertRaises(RuntimeError):
11553            a[true] = a_expanded
11554
11555    def test_getitem_scalars(self, device="mps"):
11556        zero = torch.tensor(0, dtype=torch.int64, device=device)
11557        one = torch.tensor(1, dtype=torch.int64, device=device)
11558
11559        # non-scalar indexed with scalars
11560        a = torch.randn(2, 3, device=device)
11561        self.assertEqual(a[0], a[zero])
11562        self.assertEqual(a[0][1], a[zero][one])
11563        self.assertEqual(a[0, 1], a[zero, one])
11564        self.assertEqual(a[0, one], a[zero, 1])
11565
11566        # indexing by a scalar should slice (not copy)
11567        self.assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr())
11568        self.assertEqual(a[1].data_ptr(), a[one.int()].data_ptr())
11569        self.assertEqual(a[1].data_ptr(), a[one.short()].data_ptr())
11570
11571        # scalar indexed with scalar
11572        r = torch.randn((), device=device)
11573        with self.assertRaises(IndexError):
11574            r[:]
11575        with self.assertRaises(IndexError):
11576            r[zero]
11577        self.assertEqual(r, r[...])
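        # data_ptr() equality above is how the test tells a view from a copy:
        # indexing with integer scalars (Python ints or 0-dim tensors) must
        # return a view into the same storage, unlike advanced tensor indexing,
        # which copies.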
11578
11579    def test_setitem_scalars(self, device="mps"):
11580        zero = torch.tensor(0, dtype=torch.int64)
11581
11582        # non-scalar indexed with scalars
11583        a = torch.randn(2, 3, device=device)
11584        a_set_with_number = a.clone()
11585        a_set_with_scalar = a.clone()
11586        b = torch.randn(3, device=device)
11587
11588        a_set_with_number[0] = b
11589        a_set_with_scalar[zero] = b
11590        self.assertEqual(a_set_with_number, a_set_with_scalar)
11591        a[1, zero] = 7.7
11592        self.assertEqual(7.7, a[1, 0])
11593
11594        # scalar indexed with scalars
11595        r = torch.randn((), device=device)
11596        with self.assertRaises(IndexError):
11597            r[:] = 8.8
11598        with self.assertRaises(IndexError):
11599            r[zero] = 8.8
11600        r[...] = 9.9
11601        self.assertEqual(9.9, r)
11602
11603    def test_basic_advanced_combined(self, device="mps"):
11604        # From the NumPy indexing example
11605        x = torch.arange(0, 12, device=device).view(4, 3)
11606        self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
11607        self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])
11608
11609        # Check that it is a copy
11610        unmodified = x.clone()
11611        x[1:2, [1, 2]].zero_()
11612        self.assertEqual(x, unmodified)
11613
11614        # But assignment should modify the original
11615        unmodified = x.clone()
11616        x[1:2, [1, 2]] = 0
11617        self.assertNotEqual(x, unmodified)
11618
11619    def test_int_assignment(self, device="mps"):
11620        x = torch.arange(0, 4, device=device).view(2, 2)
11621        x[1] = 5
11622        self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
11623
11624        x = torch.arange(0, 4, device=device).view(2, 2)
11625        x[1] = torch.arange(5, 7, device=device)
11626        self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
11627
11628    def test_byte_tensor_assignment(self, device="mps"):
11629        x = torch.arange(0., 16, device=device).view(4, 4)
11630        b = torch.ByteTensor([True, False, True, False]).to(device)
11631        value = torch.tensor([3., 4., 5., 6.], device=device)
11632
11633        with warnings.catch_warnings(record=True) as w:
11634            x[b] = value
11635            self.assertEqual(len(w), 1)
11636
11637        self.assertEqual(x[0], value)
11638        self.assertEqual(x[1], torch.arange(4., 8, device=device))
11639        self.assertEqual(x[2], value)
11640        self.assertEqual(x[3], torch.arange(12., 16, device=device))
11641
11642    def test_variable_slicing(self, device="mps"):
11643        x = torch.arange(0, 16, device=device).view(4, 4)
11644        indices = torch.IntTensor([0, 1]).to(device)
11645        i, j = indices
11646        self.assertEqual(x[i:j], x[0:1])
11647
11648    def test_ellipsis_tensor(self, device="mps"):
11649        x = torch.arange(0, 9, device=device).view(3, 3)
11650        idx = torch.tensor([0, 2], device=device)
11651        self.assertEqual(x[..., idx].tolist(), [[0, 2],
11652                                                [3, 5],
11653                                                [6, 8]])
11654        self.assertEqual(x[idx, ...].tolist(), [[0, 1, 2],
11655                                                [6, 7, 8]])
11656
11657    def test_invalid_index(self, device="mps"):
11658        x = torch.arange(0, 16, device=device).view(4, 4)
11659        self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"])
11660
11661    def test_out_of_bound_index(self, device="mps"):
11662        x = torch.arange(0, 100, device=device).view(2, 5, 10)
11663        self.assertRaisesRegex(IndexError, 'index 5 is out of bounds for dimension 1 with size 5', lambda: x[0, 5])
11664        self.assertRaisesRegex(IndexError, 'index 4 is out of bounds for dimension 0 with size 2', lambda: x[4, 5])
11665        self.assertRaisesRegex(IndexError, 'index 15 is out of bounds for dimension 2 with size 10',
11666                               lambda: x[0, 1, 15])
11667        self.assertRaisesRegex(IndexError, 'index 12 is out of bounds for dimension 2 with size 10',
11668                               lambda: x[:, :, 12])
11669
11670    def test_zero_dim_index(self, device="mps"):
11671        x = torch.tensor(10, device=device)
11672        self.assertEqual(x, x.item())
11673
11674        def runner():
11675            print(x[0])
11676            return x[0]
11677
11678        self.assertRaisesRegex(IndexError, 'invalid index', runner)
11679
11680    def test_cpu_indices(self, device="mps"):
11681        idx = torch.tensor([0, 1])
11682        b = torch.zeros(2, device=device)
11683        x = torch.ones(10, device=device)
11684        x[idx] = b  # index_put_
11685        ref = torch.ones(10, device=device)
11686        ref[:2] = 0
11687        self.assertEqual(x, ref, atol=0, rtol=0)
11688        out = x[idx]  # index
11689        self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
11690
11691    def test_nextafter(self, device="mps"):
11692        for dtype in [torch.float16, torch.float32]:
11693            x = torch.tensor([1, -1, 0, 0, 2, -2], device=device, dtype=dtype)
11694            y = torch.tensor([2, -2, -1, 1, -3, 3], device=device, dtype=dtype)
11695            na = torch.nextafter(x, y)
11696            na_cpu = torch.nextafter(x.cpu(), y.cpu())
11697            # greater is broken on MPS (see https://github.com/pytorch/pytorch/issues/125051), so compare on CPU
11698            na_ge_x_mps = na.cpu() > x.cpu()
11699            na_ge_x_cpu = na_cpu > x.cpu()
11700            self.assertEqual(na_ge_x_mps, na_ge_x_cpu)
11701
11702
11703class TestRNNMPS(TestCaseMPS):
11704    def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False,
11705                     seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False):
11706        rnn = nn.LSTM(
11707            input_size=input_size,
11708            hidden_size=hidden_size,
11709            num_layers=num_layers,
11710            bias=bias,
11711            bidirectional=bidirectional,
11712            batch_first=batch_first,
11713            device="cpu"
11714        )
11715        bidirectional_mul = 2 if bidirectional else 1
11716
11717        if batch_first:
11718            input = torch.randn(batch_size, seq_len, input_size, device="cpu", dtype=dtype, requires_grad=backward)
11719            hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11720                             requires_grad=backward)
11721            cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11722                             requires_grad=backward)
11723        else:
11724            input = torch.randn(seq_len, batch_size, input_size, device="cpu", dtype=dtype, requires_grad=backward)
11725            hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11726                             requires_grad=backward)
11727            cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
11728                             requires_grad=backward)
11729
11730        cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
11731
11732        rnn = rnn.to(device)
11733        input = input.to(device)
11734        hx = hx.to(device)
11735        cx = cx.to(device)
11736        output, (hn, cn) = rnn(input, (hx, cx))
11737
11738        self.assertEqual(cpu_output, output)
11739        self.assertEqual(cpu_hn, hn)
11740        self.assertEqual(cpu_cn, cn)
11741
11742        def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True):
11743            rnn = rnn.to(device)
11744            inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
11745
11746            output, (hx_out, cx_out) = rnn(inp, (hx, cx))
11747            assert output_grad_presented or states_grad_presented, "At least some outputs must be used"
11748
11749            f = 0
11750            if output_grad_presented:
11751                f = f + 3 * output.sum()
11752            if states_grad_presented:
11753                f = f + (hx_out * cx_out).sum()
11754
11755            param_names, params = zip(*rnn.named_parameters())
11756            param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
11757
11758            input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
11759            return output, param_grads, input_grad, hx_grad, cx_grad
11760
11761        if backward:
11762            grad_cases = [
11763                dict(output_grad_presented=True, states_grad_presented=True),
11764                dict(output_grad_presented=False, states_grad_presented=True),
11765                dict(output_grad_presented=True, states_grad_presented=False),
11766            ]
11767
11768            for grad_case in grad_cases:
11769                cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
11770                    get_backward_results(rnn, "cpu", input, hx, cx, **grad_case)
11771                mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
11772                    get_backward_results(rnn, device, input, hx, cx, **grad_case)
11773
11774                self.assertEqual(cpu_hx_grad, mps_hx_grad)
11775                self.assertEqual(cpu_cx_grad, mps_cx_grad)
11776                self.assertEqual(cpu_output, mps_output)
11777                self.assertEqual(cpu_input_grad, mps_input_grad)
11778                for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
11779                    self.assertEqual(cpu_weight_grad, mps_weight_grad,
11780                                     f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
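    # The helper above builds the LSTM on CPU first so the CPU and MPS runs share
    # identical weights, then moves the same module and inputs to the device and
    # compares outputs; for backward it compares torch.autograd.grad results for
    # three loss compositions (output-only, state-only, and both).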
11781
11782    LSTM_TEST_CASES = [
11783        {},  # default
11784        dict(batch_first=True),
11785        dict(bias=False),
11786        dict(bidirectional=True),
11787        dict(batch_first=True, bias=False),
11788        dict(bidirectional=True, bias=False),
11789        dict(bidirectional=True, batch_first=True),
11790        dict(bidirectional=True, batch_first=True, bias=False)
11791    ]
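    # The eight cases above cover every combination of batch_first, bias, and
    # bidirectional, so each forward/backward test sweeps the full 2**3 grid.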
11792
11793    def test_lstm_forward(self, device="mps", dtype=torch.float32):
11794        for num_layers in [1, 2, 5]:
11795            for test_options in self.LSTM_TEST_CASES:
11796                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options)
11797
11798    def test_lstm_backward(self, device="mps", dtype=torch.float32):
11799        for num_layers in [1, 2, 5]:
11800            for test_options in self.LSTM_TEST_CASES:
11801                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options)
11802
11803    def test_RNN_cell_no_broadcasting(self):
11804        def test(cell_module, input, hx, input_size, hidden_size):
11805            cell = cell_module(input_size, hidden_size, device='mps')
11806            self.assertRaises(RuntimeError, lambda: cell(input, hx))
11807
11808        def test_all(hidden_size, bad_hx, good_hx, input_size, input):
11809            test(nn.RNNCell, input, bad_hx, input_size, hidden_size)
11810            test(nn.GRUCell, input, bad_hx, input_size, hidden_size)
11811            test(nn.LSTMCell, input, (bad_hx, good_hx), input_size, hidden_size)
11812            test(nn.LSTMCell, input, (good_hx, bad_hx), input_size, hidden_size)
11813
11814        hidden_size = 20
11815        input_size = 10
11816        input = torch.randn(3, input_size, device='mps')
11817        bad_hx = torch.randn(1, hidden_size, device='mps')
11818        good_hx = torch.randn(3, hidden_size, device='mps')
11819
11820        # Test hidden/input batch size broadcasting
11821        test_all(hidden_size, bad_hx, good_hx, input_size, input)
11822
11823        # Test hx's hidden_size vs module's hidden_size broadcasting
11824        bad_hx = torch.randn(3, 1)
11825        test_all(hidden_size, bad_hx, good_hx, input_size, input)
11826
11827        # Test input's input_size vs module's input_size broadcasting
11828        bad_input = torch.randn(3, 1)
11829        test_all(hidden_size, good_hx, good_hx, input_size, bad_input)
11830
11831    def test_LSTM_cell(self):
11832        # this is just a smoke test; these modules are implemented through
11833        # autograd so no Jacobian test is needed
11834        for bias in (True, False):
11835            input = torch.randn(3, 10, device='mps')
11836            hx = torch.randn(3, 20, device='mps')
11837            cx = torch.randn(3, 20, device='mps')
11838            lstm = nn.LSTMCell(10, 20, bias=bias, device='mps')
11839            for _ in range(6):
11840                hx, cx = lstm(input, (hx, cx))
11841
11842            (hx + cx).sum().backward()
11843
11844    def test_LSTM_cell_forward_input_size(self):
11845        input = torch.randn(3, 11, device='mps')
11846        hx = torch.randn(3, 20, device='mps')
11847        cx = torch.randn(3, 20, device='mps')
11848        lstm = nn.LSTMCell(10, 20, device='mps')
11849        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
11850
11851    def test_LSTM_cell_forward_hidden_size(self):
11852        input = torch.randn(3, 10, device='mps')
11853        hx = torch.randn(3, 21, device='mps')
11854        cx = torch.randn(3, 20, device='mps')
11855        lstm = nn.LSTMCell(10, 20, device='mps')
11856        self.assertRaises(Exception, lambda: lstm(input, (hx, cx)))
11857        self.assertRaises(Exception, lambda: lstm(input, (cx, hx)))
11858
11859
11860class TestFallbackWarning(TestCase):
11861    # TODO: Remove once test_testing.py is running on MPS devices
11862    def test_no_warning_on_import(self):
11863        out = subprocess.check_output(
11864            [sys.executable, "-W", "always", "-c", "import torch"],
11865            stderr=subprocess.STDOUT,
11866            # On Windows, opening the subprocess with the default CWD makes `import torch`
11867            # fail, so just set CWD to this script's directory
11868            cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
11869        self.assertEqual(out, "")
11870
11871    def _get_not_implemented_op(self):
11872        # This can be changed once we actually implement 'lcm'
11873        # Should return fn, args, kwargs, string_version
11874        return (torch.lcm,
11875                [torch.tensor([1], device='mps'), torch.tensor([2], device='mps')], {},
11876                "torch.lcm(torch.tensor([1], device='mps'), torch.tensor([2], device='mps'))")
11877
11878    def test_error_on_not_implemented(self):
11879        fn, args, kwargs, _ = self._get_not_implemented_op()
11880
11881        with self.assertRaisesRegex(NotImplementedError, "not currently implemented for the MPS device"):
11882            fn(*args, **kwargs)
11883
11884    def test_warn_on_not_implemented_with_fallback(self):
11885        _, _, _, op = self._get_not_implemented_op()
11886        script = f"""
11887import os
11888# MUST happen before PyTorch is imported
11889os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
11890import warnings
11891
11892with warnings.catch_warnings(record=True) as w:
11893    import torch
11894
11895if len(w) > 0:
11896    print(w)
11897    exit(1)
11898
11899# This should run just fine and raise a warning about perf
11900with warnings.catch_warnings(record=True) as w:
11901    {op}
11902
11903if len(w) != 1:
11904    print(w)
11905    exit(2)
11906"""
11907        try:
11908            subprocess.check_output(
11909                [sys.executable, '-W', 'always', '-c', script],
11910                stderr=subprocess.STDOUT,
11911                # On Windows, opening the subprocess with the default CWD makes `import torch`
11912                # fail, so just set CWD to this script's directory
11913                cwd=os.path.dirname(os.path.realpath(__file__)),)
11914        except subprocess.CalledProcessError as e:
11915            if e.returncode == 1:
11916                self.assertTrue(False, "There was a warning when importing torch when PYTORCH_ENABLE_MPS_FALLBACK is set." +
11917                                       e.output.decode("utf-8"))
11918            elif e.returncode == 2:
11919                self.assertTrue(False, "There wasn't exactly one warning when running not implemented op with "
11920                                f"PYTORCH_ENABLE_MPS_FALLBACK set. {e.output}")
11921            else:
11922                self.assertTrue(False, "Running a not implemented op failed even though PYTORCH_ENABLE_MPS_FALLBACK is set. " +
11923                                       e.output.decode("utf-8"))
11924
11925class TestNoRegression(TestCase):
11926    def test_assert_close(self):
11927        a = torch.ones(1, device="mps")
11928        b = torch.zeros(1, device="mps")
11929        inf = a / b
11930        nan = b / b
11931
11932        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
11933            torch.testing.assert_close(a, inf)
11934
11935        # TODO: The NaN test is failing when all the tests in test_mps are run
11936        # together but passes when run separately. There seems to be memory
11937        # corruption which needs to be fixed for this test to be enabled.
11938        # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
11939            # torch.testing.assert_close(a, nan)
11940
11941    def test_double_error(self):
11942        with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
11943            a = torch.ones(2, dtype=torch.float64, device="mps")
11944
11945        a = torch.ones(2, device="mps")
11946        with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
11947            a = a.double()
11948
11949    def test_legacy_constructor(self):
11950        a = torch.ones(2, device="mps")
11951
11952        b = a.new(1)
11953
11954    def test_serialization_map_location(self):
11955
11956        # Ensures that cpu Tensor can be loaded on mps
11957        with tempfile.NamedTemporaryFile() as f:
11958            x = torch.rand(2)
11959            torch.save(x, f)
11960
11961            f.seek(0)
11962            x2 = torch.load(f, map_location="mps")
11963
11964            self.assertEqual(x, x2)
11965            self.assertEqual(x2.device.type, "mps")
11966
11967        # Ensures that mps Tensors can be loaded on mps
11968        with tempfile.NamedTemporaryFile() as f:
11969            x = torch.rand(2, device="mps")
11970            torch.save(x, f)
11971
11972            f.seek(0)
11973            x2 = torch.load(f)
11974
11975            self.assertEqual(x, x2)
11976            self.assertEqual(x2.device.type, "mps")
11977
11978        # Ensures that mps Tensors can be loaded on cpu
11979        with tempfile.NamedTemporaryFile() as f:
11980            x = torch.rand(2, device="mps")
11981            torch.save(x, f)
11982
11983            f.seek(0)
11984            x2 = torch.load(f, map_location="cpu")
11985
11986            self.assertEqual(x, x2)
11987            self.assertEqual(x2.device.type, "cpu")
11988
11989        # Ensures that `mps:0` Tensors can be loaded on mps
11990        with tempfile.NamedTemporaryFile() as f:
11991            x = torch.rand(2, device="mps:0")
11992            torch.save(x, f)
11993
11994            f.seek(0)
11995            x2 = torch.load(f, map_location="mps:0")
11996
11997            self.assertEqual(x, x2)
11998            self.assertEqual(x2.device.type, "mps")
11999
12000
12001MPS_DTYPES = get_all_dtypes()
12002for t in [torch.double, torch.cdouble, torch.cfloat, torch.bfloat16]:
12003    MPS_DTYPES.remove(t)
12004
12005MPS_GRAD_DTYPES = [torch.float32, torch.float16]
12006
12007
12008class TestConsistency(TestCaseMPS):
12009    # TODO: This is only used while some ops are being added.
12010    # This list should contain all ops and dtypes eventually
12011    # This can be generated automatically in the `new_mps_allowlist.txt` file
12012    # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
12013    # You most likely do NOT want to modify this manually
12014
    FP16_LOW_PRECISION_LIST = {
        'add', 'sub', 'div', 'addcdiv',
        '__rdiv__', '__rmul__',
        'nn.functional.huber_loss',
        'true_divide', 'kron',
        'gradient', 'var', 'std', 'std_mean', 'ldexp',
        'linalg.vector_norm', 'lerp',
        'addr', 'var_mean',
        'var_mean_unbiased',
        'acosh', 'asinh', 'asin',
        'masked.std',
        'nn.functional.normalize',
        'nn.functional.triplet_margin_loss',
        'nn.functional.triplet_margin_with_distance_loss',
        'nn.functional.batch_norm',
        'nn.functional.instance_norm',
        'round', 'xlogy', 'addcmul',
        'nn.functional.cross_entropy',
        'nn.functional.binary_cross_entropy',
        'nn.functional.nll_loss',
        'nn.functional.max_pool2d',
        'nn.functional.gelu',
        'nn.functional.glu',
        '_native_batch_norm_legit',
        '_batch_norm_with_update',
        'native_batch_norm',
        'softmax',
        '_softmax_backward_data',
        'log_softmax',
        'masked.softmax',
        'masked.log_softmax',
        'masked.softmin',
        'nn.functional.kl_div',
        'nn.functional.softmin',
        'cross', 'linalg.cross',
        'prod', 'masked.prod',
        'nextafter',
        'native_layer_norm',
        'nn.functional.layer_norm',
        'nn.functional.interpolate',
        'nn.functional.upsample_bilinear',
        'nn.functional.upsample_nearest',

        # for macOS 12
        'masked.normalize', 'masked.sum', 'masked.var',
        'outer',
        'sum_to_size', 'sum',
        'mul',
        'nansum', 'nanmean',
        'norm',
    }

    FP32_LOW_PRECISION_LIST = {
        # conv2d and conv_transpose2d results differ very slightly from
        # CPU/CUDA, so we use looser tolerances for FP32
        'nn.functional.conv2d',
        'nn.functional.conv_transpose2d',
        'matmul', '__rmatmul__',
        'linalg.multi_dot',
        'addbmm',
    }

    def _compute_tolerances(self, op, dtype):
        if (op.name in self.FP32_LOW_PRECISION_LIST) and dtype in [torch.float32, torch.complex64]:
            return (1e-4, 3e-5)

        if op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16:
            return (1e-2, 1e-2)

        if op.name in ['nn.functional.conv_transpose1d',
                       'nn.functional.conv_transpose2d',
                       'nn.functional.conv_transpose3d',
                       '__rmatmul__', 'addbmm', 'addmv',
                       'baddbmm', 'cov', 'matmul', 'mv'] and dtype == torch.float16:
            return (5e-2, 5e-2)
        if op.name == "masked.mean":
            return (7e-4, 2e-3)
        if op.name == "native_layer_norm":
            return (1e-4, 1.3e-5)
        if op.name in ["pow", "__rpow__"] and product_version < 13.3:
            # pow(9, 8) returns 43046716 instead of the expected 43046721.
            # Fixed in macOS 13.3+.
            return (1e-6, 2e-3 if dtype == torch.float16 else 4e-6)
        if op.name == "nn.functional.interpolate":
            return (1e-3, 1e-4)
        if op.name in ['fft.rfftn', 'fft.hfftn', 'fft.hfft2', 'fft.fft', 'fft.fftn', 'fft.rfft']:
            # TODO: Investigate why this is needed
            # See https://github.com/pytorch/pytorch/issues/120237
            return (3e-5, 3e-5)
        return (None, None)

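    # A minimal sketch (not collected as a test) of how the (atol, rtol) pair
    # returned above is consumed: it is forwarded to the comparison, where None
    # means "fall back to the default tolerances for the dtype".
    def _example_tolerance_usage(self):
        atol, rtol = (1e-2, 1e-2)  # e.g. an op from FP16_LOW_PRECISION_LIST
        torch.testing.assert_close(
            torch.tensor([1.000]), torch.tensor([1.005]), atol=atol, rtol=rtol)
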
    # Used for accept mode only
    NEW_ALLOW_LIST = defaultdict(list)
    NEW_ALLOW_LIST_GRAD = defaultdict(list)

    @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES + [torch.complex64])
    def test_output_match(self, device, dtype, op):
        self.assertEqual(device, "cpu")

        def get_samples():
            return op.sample_inputs(
                device,
                dtype,
                requires_grad=(dtype.is_floating_point or dtype.is_complex),
                # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
                set_seed=False,
            )
        cpu_samples = get_samples()

        for cpu_sample in cpu_samples:
            #
            # Forward check
            #
            # Move every tensor in the sample to mps, preserving requires_grad.
            mps_sample = cpu_sample.transform(
                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
            cpu_kwargs = cpu_sample.kwargs
            mps_args = [mps_sample.input] + list(mps_sample.args)
            mps_kwargs = mps_sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
                mps_args[1] = cpu_args[1]

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8:
                atol = 1.0
                rtol = 0.0

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)


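    # A minimal, self-contained sketch (not collected as a test) of the forward
    # parity pattern above, assuming an MPS-capable build: run the same op on
    # cpu and mps copies of the input and compare the results.
    def _example_forward_parity(self):
        x_cpu = torch.rand(4, 4)
        x_mps = x_cpu.to("mps")
        torch.testing.assert_close(torch.sigmoid(x_mps).cpu(), torch.sigmoid(x_cpu))
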
    @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
    def test_output_grad_match(self, device, dtype, op):
        self.assertEqual(device, "cpu")

        def get_samples():
            return op.sample_inputs(
                device,
                dtype,
                requires_grad=(dtype.is_floating_point or dtype.is_complex),
                # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
                set_seed=False,
            )
        cpu_samples = get_samples()

        for cpu_sample in cpu_samples:
            #
            # Forward check
            #
            forward_failed = False  # never set to True in this path; the raise below is defensive
            mps_sample = cpu_sample.transform(
                lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x)

            cpu_args = [cpu_sample.input] + list(cpu_sample.args)
            cpu_kwargs = cpu_sample.kwargs
            mps_args = [mps_sample.input] + list(mps_sample.args)
            mps_kwargs = mps_sample.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
                mps_args[1] = cpu_args[1]

            cpu_out = op(*cpu_args, **cpu_kwargs)
            mps_out = op(*mps_args, **mps_kwargs)

            if op.name == "unique" and cpu_kwargs["sorted"] is False:
                continue

            atol, rtol = self._compute_tolerances(op, dtype)
            if op.name in ["renorm", "norm", "linalg.norm"] and dtype == torch.float16:
                atol = 7e-4
                rtol = 1.5e-3

            self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)

            #
            # Backward check
            #
            if forward_failed:
                # We would've failed immediately anyway, but this error is clearer
                # We error instead of continuing so that all_backward_pass would not be True
                raise RuntimeError("Forward pass already failed")

            cpu_out = (cpu_out,) if isinstance(cpu_out, torch.Tensor) else tuple(cpu_out)
            mps_out = (mps_out,) if isinstance(mps_out, torch.Tensor) else tuple(mps_out)

            def req_grad(t):
                return isinstance(t, torch.Tensor) and t.requires_grad

            diff_cpu_out = tuple(t for t in cpu_out if req_grad(t))
            diff_mps_out = tuple(t for t in mps_out if req_grad(t))
            diff_cpu_arg = tuple(t for t in pytree.tree_leaves((cpu_args, cpu_kwargs)) if req_grad(t))
            diff_mps_arg = tuple(t for t in pytree.tree_leaves((mps_args, mps_kwargs)) if req_grad(t))
            self.assertEqual(len(diff_cpu_out), len(diff_mps_out))
            self.assertEqual(len(diff_cpu_arg), len(diff_mps_arg))

            if len(diff_cpu_out) == 0:
                continue
            # rand_like does not work with certain dtypes, so cast to double and cast back
            cpu_grad_outputs = tuple(torch.rand_like(t, dtype=torch.double).to(dtype=t.dtype) for t in diff_cpu_out)
            mps_grad_outputs = tuple(t.to("mps") for t in cpu_grad_outputs)

            # Compare computed gradients with cpu given a random grad_output vector.
            # Sometimes when the derivative is 0, we just don't bother creating the
            # graph; allow_unused is needed in those cases.
            cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
            mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)

            self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)


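# A minimal, self-contained sketch (not run by the suite) of the gradient parity
# check above, assuming an MPS-capable build: backpropagate the same random
# grad_output through the cpu and mps graphs and compare the input gradients.
def _example_grad_parity():
    x_cpu = torch.rand(3, requires_grad=True)
    x_mps = x_cpu.detach().to("mps").requires_grad_(True)
    y_cpu, y_mps = x_cpu.exp(), x_mps.exp()
    grad_out = torch.rand_like(y_cpu)
    (g_cpu,) = torch.autograd.grad(y_cpu, x_cpu, grad_outputs=grad_out)
    (g_mps,) = torch.autograd.grad(y_mps, x_mps, grad_outputs=grad_out.to("mps"))
    torch.testing.assert_close(g_mps.cpu(), g_cpu)

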
class TestErrorInputs(TestCase):
    _ignore_not_implemented_error = True

    @ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
    def test_error_inputs(self, device, op):
        self.assertEqual(device, "mps:0")

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        mps_samples = op.error_inputs(device, set_seed=False)

        for mps_sample in mps_samples:
            mps_sample_input = mps_sample.sample_input
            error_type = mps_sample.error_type
            error_regex = mps_sample.error_regex

            mps_args = [mps_sample_input.input] + list(mps_sample_input.args)
            mps_kwargs = mps_sample_input.kwargs

            # for tensor_split(), the second tensor arg ("tensor_indices_or_sections") must be on CPU only
            if op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor):
                mps_args[1] = mps_args[1].cpu()

            with self.assertRaisesRegex(error_type, error_regex):
                op(*mps_args, **mps_kwargs)

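# A minimal sketch (assuming an MPS device) of the error-input pattern above:
# the op is expected to raise a specific exception type whose message matches
# a regex. The shape mismatch below is a hypothetical example; the exact
# message is assumed to mirror the cpu one.
def _example_error_input():
    try:
        # Hypothetical bad input: broadcasting mismatch, size 2 vs size 3.
        torch.add(torch.rand(2, device="mps"), torch.rand(3, device="mps"))
        raise AssertionError("expected a RuntimeError")
    except RuntimeError as e:
        assert "must match" in str(e)
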
class TestComplex(TestCase):
    def test_tensor_scalar_binops(self):
        # Regression test for https://github.com/pytorch/pytorch/issues/119088
        def to_cpu(x):
            return x.cpu() if isinstance(x, torch.Tensor) else x

        # Allocate tensors on mps
        with torch.device("mps"):
            inputs = [torch.rand(2, dtype=dtype) for dtype in [torch.float, torch.half, torch.cfloat]]
        self.assertTrue(all(x.device.type == "mps" for x in inputs))
        # Add scalars
        inputs.extend([7, 3.14, 2 + 3j, torch.tensor(4 + 5j, dtype=torch.chalf)])

        # Iterate over all pairs of input types (int, float, complex, half) and ops (excluding div)
        for x, y in itertools.product(inputs, inputs):
            for op_name in ["__add__", "__sub__", "__mul__"]:
                x_cpu, y_cpu = map(to_cpu, (x, y))
                res = getattr(x, op_name)(y)
                res_cpu = getattr(x_cpu, op_name)(y_cpu)
                self.assertEqual(to_cpu(res), res_cpu, f"{op_name}({x}, {y}) produces different results {res} vs {res_cpu}")


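# A short illustrative sketch of the `with torch.device("mps")` pattern used in
# the test above: inside the context, factory functions default to that device.
def _example_default_device():
    with torch.device("mps"):
        t = torch.rand(2)  # no explicit device= argument, still allocated on mps
    assert t.device.type == "mps"

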
# Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS
@skipIfSlowGradcheckEnv
class TestCommon(TestCase):
    exact_dtype = True

    # Verifies, on teardown, that no OpInfo is still using dynamic dtypes in CI
    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        if IS_CI:
            err_msg = (
                "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries. "
                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
            )
            # Ensure no OpInfo entry has dynamic_dtypes
            filtered_ops = list(filter(opinfo.utils.is_dynamic_dtype_set, op_db))
            for op in filtered_ops:
                fmt_str = opinfo.utils.str_format_dynamic_dtype(op)
                err_msg += "\n" + fmt_str

            assert len(filtered_ops) == 0, err_msg

    # This is the MPS equivalent of `test_numpy_ref` from `test_ops.py`. It lives over here while
    # MPS still requires some fairly heavy special casing in the test framework.
    # When MPS becomes more consistent, this can probably be merged with that test using
    # `@dtypesIfMPS(torch.float32)`, but for now, the assertions themselves need to be loosened
    @suppress_warnings
    # MPS only supports float32
    @ops(_ref_test_ops, allowed_dtypes=(torch.float32,))
    def test_numpy_ref_mps(self, device, dtype, op):
        # Unlike `test_numpy_ref`, this test compares in `float32` since at the time of this test's creation MPS
        # does not support float64 Tensors.
        # A few ops are currently broken on their reference inputs, but not their sample inputs. These should
        # get patched up and this workaround removed.
        broken_on_ref_inputs = op.name in ('where',)

        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
        inputs = (
            op.reference_inputs(device, dtype, set_seed=False) if not broken_on_ref_inputs
            else op.sample_inputs(device, dtype, set_seed=False)
        )
        for sample_input in inputs:
            self.compare_with_reference(op, op.ref, sample_input)

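    # A minimal sketch (not collected as a test) of the reference comparison
    # performed above, assuming an MPS-capable build: evaluate an op on mps and
    # its NumPy reference on the same data, then compare in float32.
    def _example_numpy_ref(self):
        x = torch.rand(8, device="mps")
        np.testing.assert_allclose(
            torch.sin(x).cpu().numpy(), np.sin(x.cpu().numpy()), rtol=1e-5, atol=1e-5)
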
    @dtypes(*get_all_dtypes())
    def test_tensor_creation(self, device, dtype):
        def ones(device):
            return torch.ones((2, 2), dtype=dtype, device=device)
        # complex64 is always creatable on mps; bfloat16 additionally requires a
        # recent macOS (product_version > 14.0)
        if dtype not in MPS_DTYPES + ([torch.bfloat16, torch.complex64] if product_version > 14.0 else [torch.complex64]):
            with self.assertRaises(TypeError):
                ones(device)
        else:
            mps_tensor = ones(device)
            cpu_tensor = ones("cpu")
            self.assertEqual(mps_tensor.cpu(), cpu_tensor)


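# An illustrative sketch of the unsupported-dtype behavior exercised above:
# float64 is not supported by the MPS framework, so creating such a tensor on
# mps is expected to raise TypeError.
def _example_unsupported_dtype():
    try:
        torch.ones(2, dtype=torch.float64, device="mps")
    except TypeError as e:
        print(f"raised as expected: {e}")

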
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device-generic test framework, which is not
# the case right now. We can probably use `allow_mps`, introduced in
# https://github.com/pytorch/pytorch/pull/87342, to achieve this.
instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
instantiate_parametrized_tests(TestMPS)
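
# Example invocation (assuming a macOS machine with MPS support): run the whole
# file with `python test_mps.py`, or restrict to one generated suite, e.g.
# `python test_mps.py TestConsistencyCPU`.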

if __name__ == "__main__":
    run_tests()
