# xref: /aosp_15_r20/external/executorch/backends/apple/mps/test/test_mps.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#
5
6import inspect
7import logging
8import random
9import unittest
10from enum import Enum
11
12import torch
13from executorch.backends.apple.mps.test.test_mps_models import MPS_MODEL_NAME_TO_MODEL
14from executorch.backends.apple.mps.test.test_mps_utils import (
15    OpSequencesAddConv2d,
16    randomize_bn,
17    TestMPS,
18)
19from executorch.examples.models import MODEL_NAME_TO_MODEL
20from executorch.examples.models.model_factory import EagerModelFactory
21
22from executorch.exir.tests.models import (
23    BasicSinMax,
24    CompositeDelegateModule,
25    ElementwiseAdd,
26    Emformer,
27    MLP,
28    ModelWithUnusedArg,
29    Mul,
30    Repeat,
31)
32
33
class MODEL_TYPE(Enum):
    """Selects which model registry ``run_model`` resolves a model name against."""

    # Resolved via MODEL_NAME_TO_MODEL + EagerModelFactory.create_model.
    EXIR_DEFAULT_MODEL = 0
    # Resolved via EXIR_MODEL_NAME_TO_MODEL.
    EXIR_TEST_MODEL = 1
    # Resolved via MPS_MODEL_NAME_TO_MODEL.
    MPS_TEST_MODEL = 2
38
39
def _model_with_random_inputs(model_cls):
    """Instantiate ``model_cls`` once and pair it with its own random inputs.

    Building a single instance avoids constructing every model twice (once for
    the module, once just to call ``get_random_inputs``) and guarantees the
    inputs come from the very instance that is returned.
    """
    model = model_cls()
    return model, model.get_random_inputs()


# Maps a model name to a zero-argument factory returning ``(module, inputs)``;
# ``run_model`` calls the factory and unpacks the pair for EXIR_TEST_MODEL.
EXIR_MODEL_NAME_TO_MODEL = {
    "repeat": lambda: _model_with_random_inputs(Repeat),
    "model_with_unused_arg": lambda: _model_with_random_inputs(ModelWithUnusedArg),
    "mlp": lambda: _model_with_random_inputs(MLP),
    "mul_2": lambda: _model_with_random_inputs(Mul),
    "element_wise_add": lambda: _model_with_random_inputs(ElementwiseAdd),
    "basic_sin_max": lambda: _model_with_random_inputs(BasicSinMax),
    "composite_delegate_module": lambda: _model_with_random_inputs(
        CompositeDelegateModule
    ),
    "emformer": lambda: _model_with_random_inputs(Emformer),
}
58}
59
60
def run_model(
    model: str,
    model_type: MODEL_TYPE = MODEL_TYPE.EXIR_DEFAULT_MODEL,
    use_fp16: bool = False,
    lowering_func=None,
):
    """Build the named model with example inputs and pass both to ``lowering_func``.

    Args:
        model: Registry key of the model to build.
        model_type: Which registry ``model`` is looked up in.
        use_fp16: Not used by this function; kept for call-site compatibility.
        lowering_func: Callable invoked as ``lowering_func(m, m_inputs, model)``.

    Raises:
        TypeError: If ``lowering_func`` is not supplied.
        KeyError: If ``model`` is not present in the selected registry.
        ValueError: If ``model_type`` is not a handled ``MODEL_TYPE`` member.
    """
    # Fail fast instead of building a (possibly large) model and then
    # crashing on ``None(...)``.
    if lowering_func is None:
        raise TypeError("run_model() requires a lowering_func callable")

    logging.info("Step 1: Retrieving model: %s...", model)
    if model_type == MODEL_TYPE.EXIR_DEFAULT_MODEL:
        m, m_inputs = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL[model])
    elif model_type == MODEL_TYPE.EXIR_TEST_MODEL:
        # Indexing (not ``.get``) so an unknown name raises a clear KeyError.
        m, m_inputs = EXIR_MODEL_NAME_TO_MODEL[model]()
    elif model_type == MODEL_TYPE.MPS_TEST_MODEL:
        m, m_inputs = MPS_MODEL_NAME_TO_MODEL[model]()
    else:
        # Previously fell through and hit UnboundLocalError on ``m``.
        raise ValueError(f"Unsupported model type: {model_type}")

    lowering_func(m, m_inputs, model)
76
77
class TestMPSBackendExirModels(TestMPS):
    """Lower models from ``EXIR_MODEL_NAME_TO_MODEL`` through the MPS partitioner."""

    def _lower_exir_test_model(self, test_method_name):
        # The registry key is the test method name without its "test_" prefix.
        run_model(
            test_method_name[5:],
            MODEL_TYPE.EXIR_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_model_with_unused_arg(self):
        self._lower_exir_test_model(inspect.stack()[0].function)

    def test_mlp(self):
        self._lower_exir_test_model(inspect.stack()[0].function)

    def test_mul_2(self):
        self._lower_exir_test_model(inspect.stack()[0].function)

    def test_element_wise_add(self):
        self._lower_exir_test_model(inspect.stack()[0].function)

    def test_emformer(self):
        self._lower_exir_test_model(inspect.stack()[0].function)
113
114
class TestMPSBackendMPSModels(TestMPS):
    """Lower models from ``MPS_MODEL_NAME_TO_MODEL`` through the MPS partitioner."""

    def _lower_mps_test_model(self, test_method_name):
        # The registry key is the test method name without its "test_" prefix.
        run_model(
            test_method_name[5:],
            MODEL_TYPE.MPS_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_conv2D(self):
        self._lower_mps_test_model(inspect.stack()[0].function)

    def test_norm(self):
        self._lower_mps_test_model(inspect.stack()[0].function)

    def test_module_add(self):
        self._lower_mps_test_model(inspect.stack()[0].function)

    def test_toy_model_for_mem_planning(self):
        self._lower_mps_test_model(inspect.stack()[0].function)

    def test_mem_planning_with_scratch_tensor(self):
        self._lower_mps_test_model(inspect.stack()[0].function)

    def test_module_ops_return_tensor_list(self):
        self._lower_mps_test_model(inspect.stack()[0].function)

    def test_module_contiguous_tensor(self):
        self._lower_mps_test_model(inspect.stack()[0].function)

    def test_module_input_dynamic_shape(self):
        self._lower_mps_test_model(inspect.stack()[0].function)
171
172
173class TestMPSUnitOpTesting(TestMPS):
174    def test_mps_backend_split_copy(self):
175        class SplitCopy(torch.nn.Module):
176            def __init__(self):
177                super().__init__()
178
179            def forward(self, x):
180                return torch.split(x, 2, 1)
181
182        example_inputs = (torch.randn(3, 5, 4, 7),)
183        self.lower_and_test_with_partitioner(
184            SplitCopy(), example_inputs, func_name=inspect.stack()[0].function[5:]
185        )
186
187    def test_mps_backend_unbind_copy(self):
188        class UnbindCopy(torch.nn.Module):
189            def __init__(self):
190                super().__init__()
191
192            def forward(self, x):
193                return torch.unbind(x, 1)
194
195        example_inputs = (torch.randn(3, 5, 4, 7),)
196        self.lower_and_test_with_partitioner(
197            UnbindCopy(), example_inputs, func_name=inspect.stack()[0].function[5:]
198        )
199
200    def test_mps_backend_pixel_shuffle(self):
201        class PixelShuffle(torch.nn.Module):
202            def __init__(self):
203                super().__init__()
204
205            def forward(self, x):
206                return torch.pixel_shuffle(x, 2)
207
208        example_inputs = (torch.randn(3, 8, 4, 7),)
209        self.lower_and_test_with_partitioner(
210            PixelShuffle(), example_inputs, func_name=inspect.stack()[0].function[5:]
211        )
212
213    def test_mps_backend_cumsum(self):
214        class CumulativeSum(torch.nn.Module):
215            def __init__(self):
216                super().__init__()
217
218            def forward(self, *x):
219                return torch.cumsum(x[0], dim=0)
220
221        example_inputs = (torch.randn(3, 5, 4, 7),)
222        self.lower_and_test_with_partitioner(
223            CumulativeSum(), example_inputs, func_name=inspect.stack()[0].function[5:]
224        )
225
226    def test_mps_backend_stack(self):
227        class Stack(torch.nn.Module):
228            def __init__(self):
229                super().__init__()
230
231            def forward(self, *x):
232                return torch.stack((x), 0)
233
234        example_inputs = (
235            torch.randn(1, 5, 1, 8),
236            torch.randn(1, 5, 1, 8),
237        )
238        self.lower_and_test_with_partitioner(
239            Stack(), example_inputs, func_name=inspect.stack()[0].function[5:]
240        )
241
242    def test_mps_backend_cat(self):
243        class Cat(torch.nn.Module):
244            def __init__(self):
245                super().__init__()
246
247            def forward(self, *x):
248                return torch.cat((x), 1)
249
250        example_inputs = (
251            torch.randn(1, 5, 1, 8),
252            torch.randn(1, 5, 1, 8),
253        )
254        self.lower_and_test_with_partitioner(
255            Cat(), example_inputs, func_name=inspect.stack()[0].function[5:]
256        )
257
258    def test_mps_backend_expand_copy(self):
259        class ExpandCopy(torch.nn.Module):
260            def __init__(self):
261                super().__init__()
262                self.example_inputs = [7, 5, 4, 8]
263
264            def forward(self, x):
265                return x.expand(self.example_inputs)
266
267        example_inputs = (torch.randn(1, 5, 1, 8),)
268        self.lower_and_test_with_partitioner(
269            ExpandCopy(), example_inputs, func_name=inspect.stack()[0].function[5:]
270        )
271
272    def test_mps_backend_select(self):
273        class Select(torch.nn.Module):
274            def __init__(self):
275                super().__init__()
276
277            def forward(self, x):
278                return torch.select(x, 3, 2)
279
280        example_inputs = (torch.randn(3, 5, 4, 7),)
281        self.lower_and_test_with_partitioner(
282            Select(), example_inputs, func_name=inspect.stack()[0].function[5:]
283        )
284
285    def test_mps_backend_view_copy(self):
286        class ViewCopy(torch.nn.Module):
287            def __init__(self):
288                super().__init__()
289                self.example_inputs = [2, 10, 2, 4]
290
291            def forward(self, x):
292                return x.view(self.example_inputs)
293
294        example_inputs = (torch.randn(1, 5, 4, 8),)
295        self.lower_and_test_with_partitioner(
296            ViewCopy(), example_inputs, func_name=inspect.stack()[0].function[5:]
297        )
298
299    def test_mps_backend_mean_dim_2(self):
300        class Mean(torch.nn.Module):
301            def __init__(self):
302                super().__init__()
303
304            def forward(self, x):
305                return torch.mean(x, (-1, -2), keepdim=True)
306
307        example_inputs = (torch.randn(1, 5, 4, 4),)
308        self.lower_and_test_with_partitioner(
309            Mean(), example_inputs, func_name=inspect.stack()[0].function[5:]
310        )
311
312    def test_mps_backend_squeeze_dim_1(self):
313        class Squeeze(torch.nn.Module):
314            def __init__(self):
315                super().__init__()
316
317            def forward(self, x):
318                y = torch.squeeze(x, 2)
319                return torch.squeeze(y, 0)
320
321        example_inputs = (torch.randn(1, 5, 1, 1, 4),)
322        self.lower_and_test_with_partitioner(
323            Squeeze(), example_inputs, func_name=inspect.stack()[0].function[5:]
324        )
325
326    def test_mps_backend_unsqueeze_dim_1(self):
327        class Squeeze(torch.nn.Module):
328            def __init__(self):
329                super().__init__()
330
331            def forward(self, x):
332                return torch.unsqueeze(x, 1)
333
334        example_inputs = (torch.randn(1, 5, 1, 4),)
335        self.lower_and_test_with_partitioner(
336            Squeeze(), example_inputs, func_name=inspect.stack()[0].function[5:]
337        )
338
339    def test_mps_backend_mean_dim_no_keepdim(self):
340        class Mean(torch.nn.Module):
341            def __init__(self):
342                super().__init__()
343
344            def forward(self, x):
345                return torch.mean(x, (-1, -2), keepdim=False)
346
347        example_inputs = (torch.randn(1, 5, 4, 4),)
348        self.lower_and_test_with_partitioner(
349            Mean(), example_inputs, func_name=inspect.stack()[0].function[5:]
350        )
351
352    def test_mps_backend_mean_dim_unsupported(self):
353        class Mean(torch.nn.Module):
354            def __init__(self):
355                super().__init__()
356
357            def forward(self, x):
358                return torch.mean(x, (3), keepdim=True)
359
360        example_inputs = (torch.randn(1, 5, 4, 4),)
361        self.lower_and_test_with_partitioner(
362            Mean(), example_inputs, func_name=inspect.stack()[0].function[5:]
363        )
364
365    def test_mps_backend_static_transpose(self):
366        class PermuteModule(torch.nn.Module):
367            def __init__(self):
368                super().__init__()
369                self.nchw_to_nhwc = [0, 2, 3, 1]
370
371            def forward(self, x):
372                return torch.permute(x, self.nchw_to_nhwc)
373
374        example_inputs = (torch.randn(1, 1, 4, 4),)
375        self.lower_module_and_test_output(
376            PermuteModule(), example_inputs, func_name=inspect.stack()[0].function[5:]
377        )
378
379    def test_mps_backend_sequential_conv2d(self):
380        class TwoConv(torch.nn.Module):
381            def __init__(self):
382                super().__init__()
383                self.first = torch.nn.Conv2d(
384                    in_channels=1,
385                    out_channels=3,
386                    kernel_size=(3, 3),
387                    padding=1,
388                    bias=False,
389                )
390                self.second = torch.nn.Conv2d(
391                    in_channels=3,
392                    out_channels=2,
393                    kernel_size=(3, 3),
394                    padding=1,
395                    bias=False,
396                )
397
398            def forward(self, x):
399                return self.second(self.first(x))
400
401        example_inputs = (torch.randn(1, 1, 3, 3),)
402        self.lower_and_test_with_partitioner(
403            TwoConv(), example_inputs, func_name=inspect.stack()[0].function[5:]
404        )
405
406    def test_mps_backend_conv2d_bn_1(self):
407        class ModelConvBN(torch.nn.Module):
408            def __init__(self, in_features: int, out_features: int, kernel_size):
409                super().__init__()
410                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
411                self.bn = randomize_bn(out_features)
412
413            def forward(self, x):
414                y = self.conv2d(x)
415                y = self.bn(y)
416                return y
417
418        model = ModelConvBN(2, 2, (2, 2)).eval()
419
420        self.lower_and_test_with_partitioner(
421            model, (torch.randn(2, 2, 4, 4),), func_name=inspect.stack()[0].function[5:]
422        )
423
424    def test_mps_backend_conv2d(self):
425        groups = 1
426        stride = [2, 2]
427        padding = [1, 1]
428        dilation = [1, 1]
429        in_channels = 2
430        out_channels = 1
431        width = 8
432        height = 8
433        batches = 1
434        example_inputs = (torch.randn(batches, in_channels, height, width),)
435        conv = torch.nn.Conv2d(
436            in_channels=in_channels,
437            out_channels=out_channels,
438            kernel_size=(3, 3),
439            stride=stride,
440            padding=padding,
441            groups=groups,
442            dilation=dilation,
443            bias=True,
444        )
445        conv.eval()
446        self.lower_and_test_with_partitioner(
447            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
448        )
449
450    def test_conv1d(self):
451        example_inputs = (torch.randn(1, 57, 40),)
452        stride = random.randint(1, 4)
453        padding = random.randint(1, 4)
454        conv = torch.nn.Conv1d(
455            57,
456            20,
457            stride=stride,
458            padding=padding,
459            kernel_size=3,
460            bias=random.choice([True, False]),
461        )
462        conv.eval()
463        self.lower_and_test_with_partitioner(
464            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
465        )
466
467    def test_conv2d_simple(self):
468        N = 10
469        C = 10
470        H = 4
471        W = 6
472        groups = 2
473        input_memory_format = torch.contiguous_format
474        weight_memory_format = torch.contiguous_format
475        strideX = random.randint(1, 4)
476        strideY = random.randint(1, 4)
477        example_inputs = (
478            torch.randn(N, C, H, W).to(memory_format=input_memory_format),
479        )
480        conv = torch.nn.Conv2d(
481            in_channels=N,
482            out_channels=C,
483            kernel_size=H,
484            groups=groups,
485            stride=(strideX, strideY),
486            bias=False,
487        )
488        conv.weight.data = conv.weight.to(memory_format=weight_memory_format)
489        conv.eval()
490        self.lower_and_test_with_partitioner(
491            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
492        )
493
494    def test_conv2d_to_depthwise_conv_3d(self):
495        N = 10
496        C = 10
497        H = 4
498        W = 6
499        groups = 10
500        input_memory_format = torch.contiguous_format
501        weight_memory_format = torch.contiguous_format
502        strideX = random.randint(1, 4)
503        strideY = random.randint(1, 4)
504        example_inputs = (
505            torch.randn(N, C, H, W).to(memory_format=input_memory_format),
506        )
507        conv = torch.nn.Conv2d(
508            in_channels=N,
509            out_channels=C,
510            kernel_size=H,
511            groups=groups,
512            stride=(strideX, strideY),
513        )
514        conv.weight.data = conv.weight.to(memory_format=weight_memory_format)
515        conv.eval()
516        self.lower_and_test_with_partitioner(
517            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
518        )
519
520    def test_mps_backend_conv2d_single_int_params(self):
521        groups = 1
522        stride = 2
523        padding = "valid"
524        dilation = 1
525        in_channels = 2
526        out_channels = 1
527        width = 8
528        height = 8
529        batches = 1
530        example_inputs = (torch.randn(batches, in_channels, height, width),)
531        conv = torch.nn.Conv2d(
532            in_channels=in_channels,
533            out_channels=out_channels,
534            kernel_size=3,
535            stride=stride,
536            padding=padding,
537            groups=groups,
538            dilation=dilation,
539            bias=True,
540        )
541        conv.eval()
542        self.lower_and_test_with_partitioner(
543            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
544        )
545
546    def test_mps_backend_conv2d_dw(self):
547        # Depthwise Convolution Requirements:
548        # - Groups must equal In Channels
549        # - Out Channels must be a positive multiple of In Channels
550        groups = 2
551        stride = [2, 2]
552        padding = [1, 1]
553        dilation = [1, 1]
554        in_channels = groups
555        out_channels = 3 * in_channels
556        width = 8
557        height = 8
558        batches = 1
559        example_inputs = (torch.randn(batches, in_channels, height, width),)
560        conv = torch.nn.Conv2d(
561            in_channels=in_channels,
562            out_channels=out_channels,
563            kernel_size=(3, 3),
564            stride=stride,
565            padding=padding,
566            groups=groups,
567            dilation=dilation,
568            bias=True,
569        )
570        conv.eval()
571        self.lower_and_test_with_partitioner(
572            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
573        )
574
575    def test_mps_backend_mm(self):
576        in_sizes = [1, 4, 4]
577        input_sizes = [4, 37, 17]
578        output_sizes = [4, 17, 37]
579        for i, _ in enumerate(in_sizes):
580            in_size = int(in_sizes[i])
581            input_size = int(input_sizes[i])
582            output_size = int(output_sizes[i])
583            linear = torch.nn.Linear(input_size, output_size, bias=False).eval()
584            example_input = (torch.randn(in_size, input_size),)
585
586            self.lower_and_test_with_partitioner(
587                linear, example_input, func_name=inspect.stack()[0].function[5:]
588            )
589
590    def test_mps_backend_bmm(self):
591        class BmmModule(torch.nn.Module):
592            def __init__(
593                self,
594            ):
595                super().__init__()
596                self.bmm = torch.bmm
597
598            def forward(self, x, y):
599                return self.bmm(x, y)
600
601        mul_module = BmmModule()
602        model_inputs = (
603            torch.randn((3, 1, 8)),
604            torch.randn((3, 8, 1)),
605        )
606
607        self.lower_and_test_with_partitioner(
608            mul_module, model_inputs, func_name=inspect.stack()[0].function[5:]
609        )
610
611    def test_mps_backend_addmm(self):
612        in_sizes = [1, 4, 4]
613        input_sizes = [4, 37, 17]
614        output_sizes = [4, 17, 37]
615        for i, _ in enumerate(in_sizes):
616            in_size = int(in_sizes[i])
617            input_size = int(input_sizes[i])
618            output_size = int(output_sizes[i])
619            linear = torch.nn.Linear(input_size, output_size, bias=True).eval()
620            example_input = (torch.randn(in_size, input_size),)
621
622            self.lower_and_test_with_partitioner(
623                linear, example_input, func_name=inspect.stack()[0].function[5:]
624            )
625
626    def test_mps_backend_full_ones_default(self):
627        class Ones(torch.nn.Module):
628            def __init__(self):
629                super().__init__()
630
631            def forward(self):
632                size = (4, 37, 17)
633                return torch.ones(size)
634
635        self.lower_and_test_with_partitioner(
636            Ones(), (), func_name=inspect.stack()[0].function[5:]
637        )
638
639    def test_mps_backend_full_zeros_default(self):
640        class Zeros(torch.nn.Module):
641            def __init__(self):
642                super().__init__()
643
644            def forward(self):
645                size = (4, 37, 17)
646                return torch.zeros(size=size)
647
648        self.lower_and_test_with_partitioner(
649            Zeros(), (), func_name=inspect.stack()[0].function[5:]
650        )
651
652    def test_mps_backend_full_default(self):
653        class Full(torch.nn.Module):
654            def __init__(self):
655                super().__init__()
656
657            def forward(self):
658                size = (4, 37, 17)
659                return torch.full(size=size, fill_value=2.0)
660
661        self.lower_and_test_with_partitioner(
662            Full(), (), func_name=inspect.stack()[0].function[5:]
663        )
664
665    def test_mps_backend_full_like(self):
666        class Full_Like(torch.nn.Module):
667            def __init__(self):
668                super().__init__()
669
670            def forward(self, x):
671                return torch.full_like(x, fill_value=2.0)
672
673        const_module = Full_Like()
674        model_inputs = (torch.randn(4, 37, 17),)
675
676        self.lower_and_test_with_partitioner(
677            const_module, model_inputs, func_name=inspect.stack()[0].function[5:]
678        )
679
680    def test_mps_backend_logit_1(self):
681        class LogitModule(torch.nn.Module):
682            def __init__(self):
683                super().__init__()
684
685            def forward(self, x):
686                z = torch.ops.aten.logit.default(x)
687                return z
688
689        logit_module = LogitModule()
690        model_inputs = (torch.rand(5),)
691
692        self.lower_and_test_with_partitioner(
693            logit_module, model_inputs, func_name=inspect.stack()[0].function[5:]
694        )
695
696    def test_mps_backend_logit_2(self):
697        class LogitModule(torch.nn.Module):
698            def __init__(self):
699                super().__init__()
700
701            def forward(self, x):
702                z = torch.ops.aten.logit.default(x, eps=1e-6)
703                return z
704
705        logit_module = LogitModule()
706        model_inputs = (torch.rand(5),)
707
708        self.lower_and_test_with_partitioner(
709            logit_module, model_inputs, func_name=inspect.stack()[0].function[5:]
710        )
711
712    def test_mps_backend_round(self):
713        class RoundModule(torch.nn.Module):
714            def __init__(self):
715                super().__init__()
716
717            def forward(self, x):
718                out = torch.round(x)
719                return out
720
721        module = RoundModule()
722        model_inputs = (torch.randn(5, 2),)
723
724        self.lower_and_test_with_partitioner(
725            module, model_inputs, func_name=inspect.stack()[0].function[5:]
726        )
727
728    def test_mps_backend_amax(self):
729        class AmaxModule(torch.nn.Module):
730            def __init__(self):
731                super().__init__()
732
733            def forward(self, x):
734                out = torch.amax(x, 1)
735                return out
736
737        module = AmaxModule()
738        model_inputs = (torch.randn(2, 3, 4),)
739
740        self.lower_and_test_with_partitioner(
741            module, model_inputs, func_name=inspect.stack()[0].function[5:]
742        )
743
744    def test_mps_backend_amin(self):
745        class AminModule(torch.nn.Module):
746            def __init__(self):
747                super().__init__()
748
749            def forward(self, x):
750                out = torch.amin(x, 1)
751                return out
752
753        module = AminModule()
754        model_inputs = (torch.randn(2, 3, 4),)
755
756        self.lower_and_test_with_partitioner(
757            module, model_inputs, func_name=inspect.stack()[0].function[5:]
758        )
759
760    @unittest.skip
761    def test_mps_backend_min_dim(self):
762        class MinModule(torch.nn.Module):
763            def __init__(self):
764                super().__init__()
765
766            def forward(self, x):
767                out = torch.min(x, 1)
768                return out
769
770        module = MinModule()
771        model_inputs = (torch.randn(2, 3, 4),)
772
773        self.lower_and_test_with_partitioner(
774            module, model_inputs, func_name=inspect.stack()[0].function[5:]
775        )
776
777    def test_mps_backend_argmax_1(self):
778        class ArgmaxModule(torch.nn.Module):
779            def __init__(self):
780                super().__init__()
781
782            def forward(self, x):
783                out1 = torch.argmax(x, 1)
784                return out1
785
786        module = ArgmaxModule()
787        model_inputs = (torch.randn(5, 10),)
788
789        self.lower_and_test_with_partitioner(
790            module, model_inputs, func_name=inspect.stack()[0].function[5:]
791        )
792
793    def test_mps_backend_argmax_2(self):
794        class ArgmaxModule(torch.nn.Module):
795            def __init__(self):
796                super().__init__()
797
798            def forward(self, x):
799                out1 = torch.argmax(x)
800                return out1
801
802        module = ArgmaxModule()
803        model_inputs = (torch.randn(5, 10),)
804
805        self.lower_and_test_with_partitioner(
806            module, model_inputs, func_name=inspect.stack()[0].function[5:]
807        )
808
809    def test_mps_backend_argmin_1(self):
810        class ArgminModule(torch.nn.Module):
811            def __init__(self):
812                super().__init__()
813
814            def forward(self, x):
815                out1 = torch.argmin(x, 1)
816                return out1
817
818        module = ArgminModule()
819        model_inputs = (torch.randn(5, 10),)
820
821        self.lower_and_test_with_partitioner(
822            module, model_inputs, func_name=inspect.stack()[0].function[5:]
823        )
824
825    def test_mps_backend_argmin_2(self):
826        class ArgminModule(torch.nn.Module):
827            def __init__(self):
828                super().__init__()
829
830            def forward(self, x):
831                out1 = torch.argmin(x)
832                return out1
833
834        module = ArgminModule()
835        model_inputs = (torch.randn(5, 10),)
836
837        self.lower_and_test_with_partitioner(
838            module, model_inputs, func_name=inspect.stack()[0].function[5:]
839        )
840
841    def test_mps_backend_minimum(self):
842        class MinimumModule(torch.nn.Module):
843            def __init__(
844                self,
845            ):
846                super().__init__()
847                self.minimum_module = torch.minimum
848
849            def forward(self, x, y):
850                return self.minimum_module(x, y)
851
852        module = MinimumModule()
853        model_inputs = (
854            torch.randn(1, 3, 6),
855            torch.randn(1, 3, 6),
856        )
857        self.lower_and_test_with_partitioner(
858            module, model_inputs, func_name=inspect.stack()[0].function[5:]
859        )
860
861    def test_mps_backend_eq_tensor_1(self):
862        class EqModule(torch.nn.Module):
863            def __init__(self):
864                super().__init__()
865
866            def forward(self, x, y):
867                out = torch.eq(x, y)
868                return out
869
870        module = EqModule()
871        model_inputs = (
872            torch.randn(2, 3, 4),
873            torch.randn(2, 3, 4),
874        )
875
876        self.lower_and_test_with_partitioner(
877            module, model_inputs, func_name=inspect.stack()[0].function[5:]
878        )
879
880    def test_mps_backend_eq_tensor_2(self):
881        class EqModule(torch.nn.Module):
882            def __init__(self):
883                super().__init__()
884
885            def forward(self, x, y):
886                out = torch.eq(x, y)
887                return out
888
889        module = EqModule()
890        input_tensor = torch.randn(2, 3, 4)
891        model_inputs = (input_tensor, input_tensor)
892
893        self.lower_and_test_with_partitioner(
894            module, model_inputs, func_name=inspect.stack()[0].function[5:]
895        )
896
897    def test_mps_backend_eq_scalar(self):
898        class EqModule(torch.nn.Module):
899            def __init__(self):
900                super().__init__()
901
902            def forward(self, x):
903                out = torch.eq(x, 1.0)
904                return out
905
906        module = EqModule()
907        model_inputs = (torch.randn(2, 3, 4),)
908
909        self.lower_and_test_with_partitioner(
910            module, model_inputs, func_name=inspect.stack()[0].function[5:]
911        )
912
913    def test_mps_backend_ne_tensor_1(self):
914        class NeModule(torch.nn.Module):
915            def __init__(self):
916                super().__init__()
917
918            def forward(self, x, y):
919                out = torch.ne(x, y)
920                return out
921
922        module = NeModule()
923        model_inputs = (
924            torch.randn(2, 3, 4),
925            torch.randn(2, 3, 4),
926        )
927
928        self.lower_and_test_with_partitioner(
929            module, model_inputs, func_name=inspect.stack()[0].function[5:]
930        )
931
932    def test_mps_backend_ne_tensor_2(self):
933        class NeModule(torch.nn.Module):
934            def __init__(self):
935                super().__init__()
936
937            def forward(self, x, y):
938                out = torch.ne(x, y)
939                return out
940
941        module = NeModule()
942        input_tensor = torch.randn(2, 3, 4)
943        model_inputs = (input_tensor, input_tensor)
944
945        self.lower_and_test_with_partitioner(
946            module, model_inputs, func_name=inspect.stack()[0].function[5:]
947        )
948
949    def test_mps_backend_ne_scalar(self):
950        class NeModule(torch.nn.Module):
951            def __init__(self):
952                super().__init__()
953
954            def forward(self, x):
955                out = torch.ne(x, 1.0)
956                return out
957
958        module = NeModule()
959        model_inputs = (torch.randn(2, 3, 4),)
960
961        self.lower_and_test_with_partitioner(
962            module, model_inputs, func_name=inspect.stack()[0].function[5:]
963        )
964
965    def test_mps_backend_ge_tensor_1(self):
966        class GeModule(torch.nn.Module):
967            def __init__(self):
968                super().__init__()
969
970            def forward(self, x, y):
971                out = torch.ge(x, y)
972                return out
973
974        module = GeModule()
975        model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))
976
977        self.lower_and_test_with_partitioner(
978            module, model_inputs, func_name=inspect.stack()[0].function[5:]
979        )
980
981    def test_mps_backend_ge_tensor_2(self):
982        class GeModule(torch.nn.Module):
983            def __init__(self):
984                super().__init__()
985
986            def forward(self, x, y):
987                out = torch.ge(x, y)
988                return out
989
990        module = GeModule()
991
992        input_tensor = torch.randn(2, 3, 4)
993        model_inputs = (input_tensor, input_tensor)
994
995        self.lower_and_test_with_partitioner(
996            module, model_inputs, func_name=inspect.stack()[0].function[5:]
997        )
998
999    def test_mps_backend_ge_scalar(self):
1000        class GeModule(torch.nn.Module):
1001            def __init__(self):
1002                super().__init__()
1003
1004            def forward(self, x):
1005                out = torch.ge(x, 1.0)
1006                return out
1007
1008        module = GeModule()
1009        model_inputs = (torch.randn(2, 3, 4),)
1010
1011        self.lower_and_test_with_partitioner(
1012            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1013        )
1014
1015    def test_mps_backend_gt_tensor_1(self):
1016        class GtModule(torch.nn.Module):
1017            def __init__(self):
1018                super().__init__()
1019
1020            def forward(self, x, y):
1021                out = torch.gt(x, y)
1022                return out
1023
1024        module = GtModule()
1025        model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))
1026
1027        self.lower_and_test_with_partitioner(
1028            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1029        )
1030
1031    def test_mps_backend_gt_tensor_2(self):
1032        class GtModule(torch.nn.Module):
1033            def __init__(self):
1034                super().__init__()
1035
1036            def forward(self, x, y):
1037                out = torch.gt(x, y)
1038                return out
1039
1040        module = GtModule()
1041        input_tensor = torch.randn(2, 3, 4)
1042        model_inputs = (input_tensor, input_tensor)
1043
1044        self.lower_and_test_with_partitioner(
1045            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1046        )
1047
1048    def test_mps_backend_gt_scalar(self):
1049        class GtModule(torch.nn.Module):
1050            def __init__(self):
1051                super().__init__()
1052
1053            def forward(self, x):
1054                out = torch.gt(x, 1.0)
1055                return out
1056
1057        module = GtModule()
1058        model_inputs = (torch.randn(2, 3, 4),)
1059
1060        self.lower_and_test_with_partitioner(
1061            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1062        )
1063
1064    def test_mps_backend_isnan(self):
1065        class IsNanModule(torch.nn.Module):
1066            def __init__(self):
1067                super().__init__()
1068
1069            def forward(self, x):
1070                return torch.isnan(x)
1071
1072        module = IsNanModule()
1073        model_inputs = (
1074            torch.randn(8, 3, 4, 5).index_put_(
1075                indices=[torch.tensor([random.randrange(0, 8)])],
1076                values=torch.tensor(float("nan")),
1077            ),
1078        )
1079        self.lower_and_test_with_partitioner(
1080            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1081        )
1082
1083    def test_mps_backend_partitioner(self):
1084        # `index.Tensor`` is not yet natively supported
1085        # It will fall back to MPSPartitioner. Once implemented,
1086        # replace the op with an unsupported one.
1087        class IndexTensorModule(torch.nn.Module):
1088            def __init__(self):
1089                super().__init__()
1090                self.indices = torch.tensor([0, 5, 2, 3])
1091
1092            def forward(self, x):
1093                y = torch.add(x, 2.0)
1094                z = y[self.indices]
1095                r = z + x[self.indices]
1096                d = r - 2
1097                p = torch.pow(d, 4)
1098                return p / 10
1099
1100        module = IndexTensorModule()
1101
1102        model_inputs = (torch.randn(8, 3, 4, 5),)
1103        self.lower_and_test_with_partitioner(
1104            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1105        )
1106
1107    def test_mps_indexing_get_1(self):
1108        class IndexGet(torch.nn.Module):
1109            def __init__(self):
1110                super().__init__()
1111
1112            def forward(self, x):
1113                return x[[0, 1, 2], [0, 1, 0]]
1114
1115        module = IndexGet()
1116        model_inputs = (torch.tensor([[1, 2], [3, 4], [5, 6]]),)
1117
1118        self.lower_and_test_with_partitioner(
1119            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1120        )
1121
1122    def test_mps_indexing_get_2(self):
1123        class IndexGet(torch.nn.Module):
1124            def __init__(self):
1125                super().__init__()
1126
1127            def forward(self, x):
1128                return x[:, [0, 4, 2]]
1129
1130        module = IndexGet()
1131        model_inputs = (torch.randn(5, 7, 3),)
1132
1133        self.lower_and_test_with_partitioner(
1134            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1135        )
1136
1137    def test_mps_indexing_get_3(self):
1138        class IndexGet(torch.nn.Module):
1139            def __init__(self):
1140                super().__init__()
1141
1142            def forward(self, x):
1143                return x[:, [[0, 1], [4, 3]]]
1144
1145        module = IndexGet()
1146        model_inputs = (torch.randn(5, 7, 3),)
1147
1148        self.lower_and_test_with_partitioner(
1149            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1150        )
1151
1152    def test_mps_indexing_get_4(self):
1153        class IndexGet(torch.nn.Module):
1154            def __init__(self):
1155                super().__init__()
1156
1157            def forward(self, x):
1158                return x[[0, 4, 2]]
1159
1160        module = IndexGet()
1161        model_inputs = (torch.randn(5, 7, 3),)
1162
1163        self.lower_and_test_with_partitioner(
1164            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1165        )
1166
1167    def test_mps_indexing_get_5(self):
1168        class IndexGet(torch.nn.Module):
1169            def __init__(self):
1170                super().__init__()
1171
1172            def forward(self, x):
1173                return x[[0, 2, 1], :, 0]
1174
1175        module = IndexGet()
1176        model_inputs = (torch.ones(3, 2, 4),)
1177
1178        self.lower_and_test_with_partitioner(
1179            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1180        )
1181
1182    def test_mps_indices2d(self):
1183        class IndexGet(torch.nn.Module):
1184            def __init__(self):
1185                super().__init__()
1186
1187            def forward(self, x, rows, columns):
1188                return x[rows, columns]
1189
1190        module = IndexGet()
1191        x = torch.arange(0, 12).resize(4, 3)
1192        rows = torch.tensor([[0, 0], [3, 3]])
1193        columns = torch.tensor([[0, 2], [0, 2]])
1194        model_inputs = (
1195            x,
1196            rows,
1197            columns,
1198        )
1199
1200        self.lower_and_test_with_partitioner(
1201            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1202        )
1203
1204    def test_mps_slicing_using_advanced_index_for_column_0(self):
1205        class IndexGet(torch.nn.Module):
1206            def __init__(self):
1207                super().__init__()
1208
1209            def forward(self, x):
1210                return x[1:4]
1211
1212        module = IndexGet()
1213        model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)
1214
1215        self.lower_and_test_with_partitioner(
1216            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1217        )
1218
1219    def test_mps_slicing_using_advanced_index_for_column_1(self):
1220        class IndexGet(torch.nn.Module):
1221            def __init__(self):
1222                super().__init__()
1223
1224            def forward(self, x):
1225                # using advanced index for column
1226                return x[1:4, [1, 2]]
1227
1228        module = IndexGet()
1229        model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)
1230
1231        self.lower_and_test_with_partitioner(
1232            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1233        )
1234
1235    @unittest.skip
1236    def test_boolean_array_indexing(self):
1237        class IndexGet(torch.nn.Module):
1238            def __init__(self):
1239                super().__init__()
1240
1241            def forward(self, x):
1242                return x[x > 5]
1243
1244        module = IndexGet()
1245        model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)
1246
1247        self.lower_and_test_with_partitioner(
1248            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1249        )
1250
1251    def test_mps_backend_isinf(self):
1252        class IsInfModule(torch.nn.Module):
1253            def __init__(self):
1254                super().__init__()
1255
1256            def forward(self, x):
1257                return torch.isinf(x)
1258
1259        module = IsInfModule()
1260        model_inputs = (
1261            torch.randn(8, 3, 4, 5).index_put_(
1262                indices=[torch.tensor([random.randrange(0, 8)])],
1263                values=torch.tensor(float("inf")),
1264            ),
1265        )
1266        self.lower_and_test_with_partitioner(
1267            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1268        )
1269
1270    def test_mps_backend_le_tensor_1(self):
1271        class LeModule(torch.nn.Module):
1272            def __init__(self):
1273                super().__init__()
1274
1275            def forward(self, x, y):
1276                out = torch.le(x, y)
1277                return out
1278
1279        module = LeModule()
1280        model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))
1281
1282        self.lower_and_test_with_partitioner(
1283            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1284        )
1285
1286    def test_mps_backend_le_tensor_2(self):
1287        class LeModule(torch.nn.Module):
1288            def __init__(self):
1289                super().__init__()
1290
1291            def forward(self, x, y):
1292                out = torch.le(x, y)
1293                return out
1294
1295        module = LeModule()
1296        input_tensor = torch.randn(2, 3, 4)
1297        model_inputs = (input_tensor, input_tensor)
1298
1299        self.lower_and_test_with_partitioner(
1300            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1301        )
1302
1303    def test_mps_backend_le_scalar(self):
1304        class LeModule(torch.nn.Module):
1305            def __init__(self):
1306                super().__init__()
1307
1308            def forward(self, x):
1309                out = torch.le(x, 1.0)
1310                return out
1311
1312        module = LeModule()
1313        model_inputs = (torch.randn(2, 3, 4),)
1314
1315        self.lower_and_test_with_partitioner(
1316            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1317        )
1318
1319    def test_mps_backend_lt_tensor_1(self):
1320        class LtModule(torch.nn.Module):
1321            def __init__(self):
1322                super().__init__()
1323
1324            def forward(self, x, y):
1325                out = torch.lt(x, y)
1326                return out
1327
1328        module = LtModule()
1329        model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))
1330
1331        self.lower_and_test_with_partitioner(
1332            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1333        )
1334
1335    def test_mps_backend_lt_tensor_2(self):
1336        class LtModule(torch.nn.Module):
1337            def __init__(self):
1338                super().__init__()
1339
1340            def forward(self, x, y):
1341                out = torch.le(x, y)
1342                return out
1343
1344        module = LtModule()
1345        input_tensor = torch.randn(2, 3, 4)
1346        model_inputs = (input_tensor, input_tensor)
1347
1348        self.lower_and_test_with_partitioner(
1349            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1350        )
1351
1352    def test_mps_backend_lt_scalar(self):
1353        class LtModule(torch.nn.Module):
1354            def __init__(self):
1355                super().__init__()
1356
1357            def forward(self, x):
1358                out = torch.lt(x, 1.0)
1359                return out
1360
1361        module = LtModule()
1362        model_inputs = (torch.randn(2, 3, 4),)
1363
1364        self.lower_and_test_with_partitioner(
1365            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1366        )
1367
1368    @torch.inference_mode()  # TODO Use  for capturing.
1369    def test_mps_backend_linear(self):
1370        in_size = 2
1371        input_size = 3
1372        output_size = 4
1373        linear = torch.nn.Linear(input_size, output_size).eval()
1374        example_input = (torch.randn(in_size, input_size),)
1375
1376        self.lower_and_test_with_partitioner(
1377            linear, example_input, func_name=inspect.stack()[0].function[5:]
1378        )
1379
1380    def test_mps_backend_glu(self):
1381        class GLUModule(torch.nn.Module):
1382            def __init__(self, dim):
1383                super().__init__()
1384                self.glu = torch.nn.GLU(dim=dim)
1385
1386            def forward(self, x):
1387                return self.glu(x)
1388
1389        shape = (4, 2)
1390        for dim in list(range(len(shape))) + [-1]:
1391            model_inputs = (torch.rand(shape),)
1392            glu_module = GLUModule(dim)
1393            self.lower_and_test_with_partitioner(
1394                glu_module, model_inputs, func_name=inspect.stack()[0].function[5:]
1395            )
1396
1397    def test_mps_backend_softmax(self):
1398        class SoftMaxModule(torch.nn.Module):
1399            def __init__(self, dim):
1400                super().__init__()
1401                self.softmax = torch.nn.Softmax(dim=dim)
1402
1403            def forward(self, x):
1404                return self.softmax(x)
1405
1406        shape = (3, 5, 7)
1407        for dim in list(range(len(shape))) + [-1]:
1408            model_inputs = (torch.rand(shape),)
1409            softmax_module = SoftMaxModule(dim)
1410            self.lower_and_test_with_partitioner(
1411                softmax_module, model_inputs, func_name=inspect.stack()[0].function[5:]
1412            )
1413
1414    def test_mps_backend_log_softmax(self):
1415        class LogSoftMaxModule(torch.nn.Module):
1416            def __init__(self, dim):
1417                super().__init__()
1418                self.logsoftmax = torch.nn.LogSoftmax(dim=dim)
1419
1420            def forward(self, x):
1421                return self.logsoftmax(x)
1422
1423        shape = (3, 5, 7)
1424        for dim in list(range(len(shape))) + [-1]:
1425            model_inputs = (torch.rand(shape),)
1426            logsoftmax_module = LogSoftMaxModule(dim)
1427
1428            self.lower_and_test_with_partitioner(
1429                logsoftmax_module,
1430                model_inputs,
1431                func_name=inspect.stack()[0].function[5:],
1432            )
1433
1434    def test_mps_backend_hardtanh(self):
1435        class HardTanhModule(torch.nn.Module):
1436            def __init__(self, min_val=-1.0, max_val=1.0):
1437                super().__init__()
1438                self.hardtanh = torch.nn.Hardtanh(min_val, max_val)
1439
1440            def forward(self, x):
1441                return self.hardtanh(x)
1442
1443        inputs = [torch.randn(2, 3, 4), torch.randn(7, 5, 2), torch.randn(2, 9)]
1444        for test_input in inputs:
1445            hardtanh_model = HardTanhModule()
1446            self.lower_and_test_with_partitioner(
1447                hardtanh_model, (test_input,), func_name=inspect.stack()[0].function[5:]
1448            )
1449
1450        for test_input in inputs:
1451            hardtanh_model = HardTanhModule(-2, 2)
1452            self.lower_and_test_with_partitioner(
1453                hardtanh_model, (test_input,), func_name=inspect.stack()[0].function[5:]
1454            )
1455
1456    def test_mps_backend_Relu(self):
1457        class ReluModule(torch.nn.Module):
1458            def __init__(self):
1459                super().__init__()
1460                self.relu = torch.nn.ReLU()
1461
1462            def forward(self, x):
1463                return self.relu(x)
1464
1465        example_input = torch.randn(2, 3, 4)
1466        self.lower_and_test_with_partitioner(
1467            ReluModule(), (example_input,), func_name=inspect.stack()[0].function[5:]
1468        )
1469
1470    def test_mps_backend_GELU(self):
1471        class GELUModule(torch.nn.Module):
1472            def __init__(self):
1473                super().__init__()
1474                self.gelu = torch.nn.GELU()
1475                self.gelu_tanh = torch.nn.GELU(approximate="tanh")
1476
1477            def forward(self, x):
1478                return self.gelu(x)
1479                # MPS TODO: MPS Gelu tanh fails
1480                # return self.gelu_tanh(y)
1481
1482        example_input = torch.randn(2, 3, 4)
1483        self.lower_and_test_with_partitioner(
1484            GELUModule(), (example_input,), func_name=inspect.stack()[0].function[5:]
1485        )
1486
1487    def test_mps_backend_leaky_Relu(self):
1488        class LeakyReluModule(torch.nn.Module):
1489            def __init__(self):
1490                super().__init__()
1491                self.leaky_relu = torch.nn.LeakyReLU()
1492                self.leaky_relu_2 = torch.nn.LeakyReLU(1.0)
1493
1494            def forward(self, x):
1495                out = self.leaky_relu(x)
1496                out = self.leaky_relu_2(out)
1497                return out
1498
1499        example_input = torch.randn(2, 3, 4)
1500        self.lower_and_test_with_partitioner(
1501            LeakyReluModule(),
1502            (example_input,),
1503            func_name=inspect.stack()[0].function[5:],
1504        )
1505
1506    def test_mps_backend_sigmoid(self):
1507        class SigmoidModule(torch.nn.Module):
1508            def __init__(self):
1509                super().__init__()
1510                self.sigmoid = torch.nn.Sigmoid()
1511
1512            def forward(self, x):
1513                return self.sigmoid(x)
1514
1515        model_inputs = (torch.rand(7, 5, 3),)
1516        sigmoid_module = SigmoidModule()
1517        self.lower_and_test_with_partitioner(
1518            sigmoid_module, model_inputs, func_name=inspect.stack()[0].function[5:]
1519        )
1520
1521    def test_mps_backend_constant_pad_nd(self):
1522        class PadModule(torch.nn.Module):
1523            def __init__(self):
1524                super().__init__()
1525                self.constant_pad = torch.nn.ConstantPad2d((1, 2), 0)
1526
1527            def forward(self, x):
1528                return self.constant_pad(x)
1529
1530        model_inputs = (torch.rand(1, 2, 3, 4),)
1531        pad_module = PadModule()
1532        self.lower_and_test_with_partitioner(
1533            pad_module, model_inputs, func_name=inspect.stack()[0].function[5:]
1534        )
1535
1536    def test_mps_backend_index_select(self):
1537        class IndexSelectModule(torch.nn.Module):
1538            def __init__(self):
1539                super().__init__()
1540
1541            def forward(self, input, index):
1542                return torch.index_select(input, dim=2, index=index)
1543
1544        model_inputs = (torch.rand(2, 8, 4, 5), torch.tensor([3, 0, 1]))
1545        module = IndexSelectModule()
1546        self.lower_and_test_with_partitioner(
1547            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1548        )
1549
1550    def test_mps_backend_empty(self):
1551        class EmptyModule(torch.nn.Module):
1552            def __init__(self):
1553                super().__init__()
1554
1555            def forward(self):
1556                return torch.empty((3, 4, 5), dtype=torch.float32)
1557
1558        self.lower_and_test_with_partitioner(
1559            EmptyModule(), (), func_name=inspect.stack()[0].function[5:]
1560        )
1561
1562    def test_mps_backend_static_constant_pad(self):
1563        class StaticConstantPadModule(torch.nn.Module):
1564            def __init__(self):
1565                super().__init__()
1566
1567            def forward(self, x, y, z):
1568                pad_6 = (1, 2, 3, 4, 5, 6)
1569                pad_4 = (1, 2, 3, 4)
1570                pad_2 = (1, 2)
1571                a = torch.nn.functional.pad(
1572                    input=x,
1573                    pad=pad_6,
1574                    mode="constant",
1575                    value=2.3,
1576                )
1577                b = torch.nn.functional.pad(
1578                    input=x,
1579                    pad=pad_4,
1580                    mode="constant",
1581                    value=1.3,
1582                )
1583                c = torch.nn.functional.pad(
1584                    input=x,
1585                    pad=pad_2,
1586                    mode="constant",
1587                    value=2.1,
1588                )
1589                d = torch.nn.functional.pad(
1590                    input=y,
1591                    pad=pad_6,
1592                    mode="constant",
1593                    value=2.7,
1594                )
1595                e = torch.nn.functional.pad(
1596                    input=y,
1597                    pad=pad_4,
1598                    mode="constant",
1599                    value=1.9,
1600                )
1601                f = torch.nn.functional.pad(
1602                    input=y,
1603                    pad=pad_2,
1604                    mode="constant",
1605                    value=3.1,
1606                )
1607                g = torch.nn.functional.pad(
1608                    input=z,
1609                    pad=pad_4,
1610                    mode="constant",
1611                    value=2.9,
1612                )
1613                h = torch.nn.functional.pad(
1614                    input=z,
1615                    pad=pad_2,
1616                    mode="constant",
1617                    value=1.2,
1618                )
1619                return (a, b, c, d, e, f, g, h)
1620
1621        example_inputs = (
1622            torch.randn(size=(5, 4, 3, 2)),
1623            torch.randn(size=(5, 3, 2)),
1624            torch.randn(size=(4, 3)),
1625        )
1626        self.lower_and_test_with_partitioner(
1627            StaticConstantPadModule(),
1628            example_inputs,
1629            func_name=inspect.stack()[0].function[5:],
1630        )
1631
1632    def test_mps_clamp_min_max(self):
1633        class Clamp(torch.nn.Module):
1634            def __init__(self, min_val, max_val):
1635                super().__init__()
1636                self.clamp = torch.clamp
1637                self.min_val = min_val
1638                self.max_val = max_val
1639
1640            def forward(self, *x):
1641                out1 = self.clamp(x[0], min=-0.5, max=0.5)
1642                out2 = self.clamp(x[0], min=-5, max=5)
1643                return out1, out2
1644
1645        model_inputs = (
1646            torch.randn(1, 4, 122, 122) * 2,
1647            torch.randint(-100, 100, (1, 4, 15, 20)),
1648        )
1649        module = Clamp(-0.5, 0.5)
1650        self.lower_and_test_with_partitioner(
1651            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1652        )
1653
1654    def test_mps_clamp_min(self):
1655        class Clamp(torch.nn.Module):
1656            def __init__(self, min_val, max_val):
1657                super().__init__()
1658                self.clamp = torch.clamp
1659                self.min_val = min_val
1660                self.max_val = max_val
1661
1662            def forward(self, x):
1663                return self.clamp(x, min=self.min_val, max=self.max_val)
1664
1665        model_inputs = (torch.randn(1, 4, 122, 122) * 2,)
1666        module = Clamp(-0.5, None)
1667        self.lower_and_test_with_partitioner(
1668            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1669        )
1670
1671    def test_mps_clamp_max(self):
1672        class Clamp(torch.nn.Module):
1673            def __init__(self, min_val, max_val):
1674                super().__init__()
1675                self.clamp = torch.clamp
1676                self.min_val = min_val
1677                self.max_val = max_val
1678
1679            def forward(self, x):
1680                return self.clamp(x, min=self.min_val, max=self.max_val)
1681
1682        model_inputs = (torch.randn(1, 4, 122, 122) * 2,)
1683        module = Clamp(None, 0.5)
1684        self.lower_and_test_with_partitioner(
1685            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1686        )
1687
1688    def test_mps_backend_maxpool2d_default(self):
1689        class MaxPool2dModule(torch.nn.Module):
1690            def __init__(
1691                self,
1692                kernel_size=3,
1693                stride=1,
1694                padding=3,
1695                dilation=1,
1696            ):
1697                super().__init__()
1698                self.max_pool2d_module = torch.nn.MaxPool2d(
1699                    kernel_size=kernel_size,
1700                    stride=stride,
1701                    padding=padding,
1702                    dilation=dilation,
1703                )
1704
1705            def forward(self, x):
1706                return self.max_pool2d_module(x)
1707
1708        maxpool2d_module = MaxPool2dModule(3, 1, 0, 1)
1709        model_inputs = (torch.randn(4, 3, 24, 24),)
1710
1711        self.lower_and_test_with_partitioner(
1712            maxpool2d_module, model_inputs, func_name=inspect.stack()[0].function[5:]
1713        )
1714
1715    def test_mps_backend_maxpool2d_unsupported(self):
1716        class MaxPool2dModule(torch.nn.Module):
1717            def __init__(
1718                self,
1719                kernel_size=3,
1720                stride=1,
1721                padding=0,
1722                dilation=1,
1723            ):
1724                super().__init__()
1725                self.max_pool2d_module = torch.nn.MaxPool2d(
1726                    kernel_size=kernel_size,
1727                    stride=stride,
1728                    padding=padding,
1729                    dilation=dilation,
1730                    return_indices=True,
1731                )
1732
1733            def forward(self, x):
1734                return self.max_pool2d_module(x)[1]
1735
1736        maxpool2d_module = MaxPool2dModule(3, 1, 0, 1)
1737        model_inputs = (torch.randn(4, 3, 24, 24),)
1738
1739        self.lower_and_test_with_partitioner(
1740            maxpool2d_module, model_inputs, func_name=inspect.stack()[0].function[5:]
1741        )
1742
1743    def test_mps_backend_max_dim_vals(self):
1744        class MaxModule(torch.nn.Module):
1745            def __init__(
1746                self,
1747            ):
1748                super().__init__()
1749
1750            def forward(self, x):
1751                max_vals, _ = torch.max(x, dim=3, keepdim=True)
1752                return max_vals
1753
1754        model_inputs = (torch.randn(16, 3, 12, 12),)
1755        max_dim_module = MaxModule()
1756
1757        self.lower_and_test_with_partitioner(
1758            max_dim_module, model_inputs, func_name=inspect.stack()[0].function[5:]
1759        )
1760
1761    def test_mps_backend_max_dim(self):
1762        class MaxModule(torch.nn.Module):
1763            def __init__(
1764                self,
1765            ):
1766                super().__init__()
1767
1768            def forward(self, x):
1769                x = torch.add(x, x)
1770                max_values_1, max_indices_1 = torch.max(x, dim=2, keepdim=True)
1771                max_values_2, max_indices_2 = torch.max(x, dim=3, keepdim=True)
1772                return (max_values_1, max_indices_1, max_values_2, max_indices_2)
1773
1774        model_inputs = (torch.randn(16, 3, 12, 12),)
1775        max_dim_module = MaxModule()
1776
1777        self.lower_and_test_with_partitioner(
1778            max_dim_module, model_inputs, func_name=inspect.stack()[0].function[5:]
1779        )
1780
1781    def test_mps_backend_multiply(self):
1782        class MulModule(torch.nn.Module):
1783            def __init__(
1784                self,
1785            ):
1786                super().__init__()
1787                self.mul = torch.mul
1788
1789            def forward(self, x, y):
1790                return self.mul(x, y)
1791
1792        mul_module = MulModule()
1793        model_inputs = (
1794            torch.randn((1, 8)),
1795            torch.randn((8, 1)),
1796        )
1797
1798        self.lower_and_test_with_partitioner(
1799            mul_module, model_inputs, func_name=inspect.stack()[0].function[5:]
1800        )
1801
1802    def test_mps_backend_sub(self):
1803        class Sub(torch.nn.Module):
1804            def __init__(self):
1805                super().__init__()
1806                self.sub = torch.sub
1807
1808            def forward(self, x, y):
1809                return self.sub(x, y)
1810
1811        module = Sub()
1812        M = torch.randn(2, 3)
1813        N = torch.randn(2, 3)
1814        model_inputs = (
1815            M,
1816            N,
1817        )
1818        self.lower_and_test_with_partitioner(
1819            module, model_inputs, func_name=inspect.stack()[0].function[5:]
1820        )
1821
1822    def test_mps_backend_clone(self):
1823        class Clone(torch.nn.Module):
1824            def forward(self, x):
1825                return torch.clone(x)
1826
1827        model_inputs = (torch.randn(1, 3, 3),)
1828        self.lower_and_test_with_partitioner(
1829            Clone(), model_inputs, func_name=inspect.stack()[0].function[5:]
1830        )
1831
1832    def test_mps_backend_floor(self):
1833        class Floor(torch.nn.Module):
1834            def forward(self, x):
1835                return torch.floor(x)
1836
1837        model_inputs = (torch.randn(1, 3, 3),)
1838        self.lower_and_test_with_partitioner(
1839            Floor(), model_inputs, func_name=inspect.stack()[0].function[5:]
1840        )
1841
1842    def test_mps_backend_sqrt(self):
1843        class Sqrt(torch.nn.Module):
1844            def forward(self, x):
1845                return torch.sqrt(x)
1846
1847        model_inputs = (torch.randn(1, 3, 3).abs(),)
1848        self.lower_and_test_with_partitioner(
1849            Sqrt(), model_inputs, func_name=inspect.stack()[0].function[5:]
1850        )
1851
1852    def test_mps_backend_ceil(self):
1853        class Ceil(torch.nn.Module):
1854            def forward(self, x):
1855                return torch.ceil(x)
1856
1857        model_inputs = (torch.randn(1, 3, 3),)
1858        self.lower_and_test_with_partitioner(
1859            Ceil(), model_inputs, func_name=inspect.stack()[0].function[5:]
1860        )
1861
1862    def test_mps_backend_hardswish(self):
1863        model_inputs = (torch.randn(1, 3, 3),)
1864
1865        class HardswishModule(torch.nn.Module):
1866            def __init__(self):
1867                super(HardswishModule, self).__init__()
1868                self.hardswish_out_of_place = torch.nn.Hardswish()
1869                self.hardswish_in_place = torch.nn.Hardswish(inplace=True)
1870                self.hardswish_functional = torch.nn.functional.hardswish
1871
1872            def forward(self, x):
1873                a = self.hardswish_out_of_place(x)
1874                a = self.hardswish_in_place(a)
1875                a = self.hardswish_functional(a)
1876                return a
1877
1878        # TODO(T158969708)
1879        self.lower_and_test_with_partitioner(
1880            HardswishModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
1881        )
1882
1883    def test_mps_backend_leaky_relu(self):
1884        model_inputs = (torch.randn(1, 3, 3),)
1885
1886        class LeakyReLUModule(torch.nn.Module):
1887            def __init__(self):
1888                super(LeakyReLUModule, self).__init__()
1889                self.leaky_relu_out_of_place = torch.nn.LeakyReLU(negative_slope=0.2)
1890                self.leaky_relu_in_place = torch.nn.LeakyReLU(
1891                    negative_slope=0.08, inplace=True
1892                )
1893                self.leaky_relu_functional_default = torch.nn.functional.leaky_relu
1894
1895            def forward(self, x):
1896                a = self.leaky_relu_out_of_place(x)
1897                a = self.leaky_relu_in_place(a)
1898                a = self.leaky_relu_functional_default(a)
1899                return a
1900
1901        self.lower_and_test_with_partitioner(
1902            LeakyReLUModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
1903        )
1904
1905    @unittest.skip
1906    def test_mps_channels_last_tagged_reshape_pass_output(self):
1907        op_sequences = OpSequencesAddConv2d(2, 2)
1908        op_sequences.eval()
1909
1910        example_inputs = (torch.ones(1, 1, 6, 6),)
1911
1912        self.lower_and_test_with_partitioner(
1913            op_sequences, example_inputs, func_name=inspect.stack()[0].function[5:]
1914        )
1915
1916    def test_mps_backend_conv2d_bn_hardtanh_mean_sequence(self):
1917        """
1918        This test makes sure that we can fuse batchnorm and hardtanh
1919        even with inserting copy nodes at some spots in the graph to change
1920        memory format
1921        """
1922        groups = 1
1923        stride = [2, 2]
1924        padding = [1, 1]
1925        dilation = [1, 1]
1926        in_channels = 2
1927        out_channels = 1
1928        width = 8
1929        height = 8
1930        batches = 1
1931        example_inputs = (torch.randn(batches, in_channels, height, width),)
1932
1933        class TestModule(torch.nn.Module):
1934            def __init__(self):
1935                super(TestModule, self).__init__()
1936                self.conv = torch.nn.Conv2d(
1937                    in_channels=in_channels,
1938                    out_channels=out_channels,
1939                    kernel_size=(3, 3),
1940                    stride=stride,
1941                    padding=padding,
1942                    groups=groups,
1943                    dilation=dilation,
1944                    bias=True,
1945                )
1946                self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
1947                self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
1948
1949            def forward(self, x):
1950                x = self.conv(x)
1951                x = self.native_batchnorm(x)
1952                x = self.hardtanh(x)
1953                x = torch.mean(x, (-1, -2), keepdim=True)
1954                return x
1955
1956        test_module = TestModule()
1957        test_module.eval()
1958        self.lower_and_test_with_partitioner(
1959            test_module, example_inputs, func_name=inspect.stack()[0].function[5:]
1960        )
1961
1962    @unittest.expectedFailure
1963    def test_mps_backend_maximum_no_broadcast(self):
1964        model_inputs_no_broadcast = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))
1965
1966        self.lower_and_test_with_partitioner(
1967            torch.maximum,
1968            model_inputs_no_broadcast,
1969            func_name=inspect.stack()[0].function[5:],
1970        )
1971
1972    @unittest.expectedFailure
1973    def test_mps_backend_maximum_broadcast(self):
1974        model_inputs_broadcast = (torch.randn(2, 3, 4), torch.randn(2, 1, 4))
1975
1976        self.lower_and_test_with_partitioner(
1977            torch.maximum,
1978            model_inputs_broadcast,
1979            func_name=inspect.stack()[0].function[5:],
1980        )
1981
1982    def test_mps_backend_negative(self):
1983        model_inputs = (torch.randn(1, 3, 3),)
1984
1985        class NegModule(torch.nn.Module):
1986            def __init__(self):
1987                super().__init__()
1988
1989            def forward(self, x):
1990                return torch.neg(x)
1991
1992        self.lower_and_test_with_partitioner(
1993            NegModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
1994        )
1995
1996    def test_mps_backend_remainder_1(self):
1997        model_inputs = (torch.randn(1, 3, 3), torch.randn(1, 3, 3))
1998
1999        class RemainderModule(torch.nn.Module):
2000            def __init__(self):
2001                super().__init__()
2002
2003            def forward(self, x, y):
2004                return torch.remainder(x, y)
2005
2006        self.lower_and_test_with_partitioner(
2007            RemainderModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2008        )
2009
2010    def test_mps_backend_remainder_2(self):
2011        model_inputs = (torch.randn(1, 3, 3),)
2012
2013        class RemainderModule(torch.nn.Module):
2014            def __init__(self):
2015                super().__init__()
2016
2017            def forward(self, x):
2018                return torch.remainder(x, 0.5)
2019
2020        self.lower_and_test_with_partitioner(
2021            RemainderModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2022        )
2023
2024    def test_mps_backend_square(self):
2025        model_inputs = (torch.randn(1, 3, 3),)
2026
2027        class SquareModule(torch.nn.Module):
2028            def __init__(self):
2029                super().__init__()
2030
2031            def forward(self, x):
2032                return torch.square(x)
2033
2034        self.lower_and_test_with_partitioner(
2035            SquareModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2036        )
2037
2038    def test_mps_backend_pow_1(self):
2039        model_inputs = (torch.randn(1, 3, 3),)
2040
2041        class PowModule(torch.nn.Module):
2042            def __init__(self):
2043                super().__init__()
2044
2045            def forward(self, x):
2046                return torch.pow(x, 4)
2047
2048        self.lower_and_test_with_partitioner(
2049            PowModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2050        )
2051
2052    def test_mps_backend_pow_2(self):
2053        model_inputs = (torch.randn(1, 3, 3), torch.tensor(4))
2054
2055        class PowModule(torch.nn.Module):
2056            def __init__(self):
2057                super().__init__()
2058
2059            def forward(self, x, y):
2060                return torch.pow(x, y)
2061
2062        self.lower_and_test_with_partitioner(
2063            PowModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2064        )
2065
2066    def test_mps_backend_elu(self):
2067        model_inputs = (torch.randn(1, 3, 3),)
2068
2069        class ELUModule(torch.nn.Module):
2070            def __init__(self):
2071                super().__init__()
2072
2073            def forward(self, x):
2074                return torch.square(x)
2075
2076        self.lower_and_test_with_partitioner(
2077            ELUModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2078        )
2079
2080    def test_mps_backend_avg_pool_2d_1(self):
2081        model_inputs = (torch.randn(1, 1, 10, 10),)
2082
2083        class AvgPoolModule(torch.nn.Module):
2084            def __init__(self):
2085                super().__init__()
2086                self.avgPool = torch.nn.AvgPool2d(
2087                    kernel_size=(2, 2),
2088                    padding=(1, 1),
2089                    stride=(2, 2),
2090                    count_include_pad=False,
2091                )
2092
2093            def forward(self, x):
2094                return self.avgPool(x)
2095
2096        self.lower_and_test_with_partitioner(
2097            AvgPoolModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2098        )
2099
2100    def test_mps_backend_avg_pool_2d_2(self):
2101        model_inputs = (torch.randn(1, 1, 10, 10),)
2102
2103        class AvgPoolModule(torch.nn.Module):
2104            def __init__(self):
2105                super().__init__()
2106                self.avgPool = torch.nn.AvgPool2d(
2107                    kernel_size=(2, 2),
2108                    padding=(1, 1),
2109                    stride=(2, 2),
2110                    count_include_pad=True,
2111                )
2112
2113            def forward(self, x):
2114                return self.avgPool(x)
2115
2116        self.lower_and_test_with_partitioner(
2117            AvgPoolModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2118        )
2119
2120    def test_mps_backend_avg_pool_2d_3(self):
2121        model_inputs = (torch.randn(1, 1, 10, 10),)
2122
2123        class AvgPoolModule(torch.nn.Module):
2124            def __init__(self):
2125                super().__init__()
2126                self.avgPool = torch.nn.AvgPool2d(
2127                    kernel_size=(2, 2),
2128                    padding=(1, 1),
2129                    stride=(2, 2),
2130                    count_include_pad=False,
2131                    ceil_mode=True,
2132                    divisor_override=4,
2133                )
2134
2135            def forward(self, x):
2136                return self.avgPool(x)
2137
2138        self.lower_and_test_with_partitioner(
2139            AvgPoolModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2140        )
2141
2142    def test_mps_backend_abs(self):
2143        model_inputs = (torch.randn(1, 3, 3),)
2144
2145        class AbsModule(torch.nn.Module):
2146            def forward(self, x):
2147                return torch.abs(x)
2148
2149        self.lower_and_test_with_partitioner(
2150            AbsModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2151        )
2152
2153    def test_mps_backend_sign(self):
2154        model_inputs = (torch.randn(1, 3, 3),)
2155
2156        class SignModule(torch.nn.Module):
2157            def forward(self, x):
2158                return torch.sign(x)
2159
2160        self.lower_and_test_with_partitioner(
2161            SignModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2162        )
2163
2164    def test_mps_backend_rsqrt(self):
2165        model_inputs = (torch.randn(1, 3, 3).abs(),)
2166
2167        class RsqrtModule(torch.nn.Module):
2168            def forward(self, x):
2169                return torch.rsqrt(x)
2170
2171        self.lower_and_test_with_partitioner(
2172            RsqrtModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2173        )
2174
2175    def test_mps_backend_prelu(self):
2176        num_channels = 5
2177        model_inputs = (torch.randn(1, num_channels, 3, 2),)
2178
2179        class PReLUModule(torch.nn.Module):
2180            def __init__(self):
2181                super(PReLUModule, self).__init__()
2182                self.prelu = torch.nn.PReLU()
2183                self.prelu_non_default = torch.nn.PReLU(
2184                    num_parameters=num_channels, init=0.2
2185                )
2186
2187            def forward(self, x):
2188                a = self.prelu(x)
2189                a = self.prelu_non_default(a)
2190                return a
2191
2192        self.lower_and_test_with_partitioner(
2193            PReLUModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2194        )
2195
2196        # Should fail to be partitioned since constraint (input dim) is violated
2197        self.assertRaises(
2198            Exception,
2199            self.lower_and_test_with_partitioner,
2200            torch.nn.PReLU(),
2201            (torch.randn(1, 2),),
2202        )
2203
2204    def test_mps_backend_concatenate2(self):
2205        class Concat(torch.nn.Module):
2206            def forward(self, x, y):
2207                return torch.cat((y, x), 0)
2208
2209        self.lower_and_test_with_partitioner(
2210            Concat(),
2211            (torch.ones(4, 2, 3), torch.randn(1, 2, 3)),
2212            func_name=inspect.stack()[0].function[5:],
2213        )
2214
2215    def test_mps_backend_concatenate3(self):
2216        class Concat(torch.nn.Module):
2217            def forward(self, x, y):
2218                return torch.concat((y, y, x), 0)
2219
2220        self.lower_and_test_with_partitioner(
2221            Concat(),
2222            (torch.ones(4, 2, 3), torch.randn(1, 2, 3)),
2223            func_name=inspect.stack()[0].function[5:],
2224        )
2225
2226    def test_mps_backend_concatenate4(self):
2227        class Concat(torch.nn.Module):
2228            def forward(self, x, y):
2229                return torch.concatenate((y, x, y, x), 2)
2230
2231        self.lower_and_test_with_partitioner(
2232            Concat(),
2233            (torch.randn(1, 2, 3), torch.randn(1, 2, 5)),
2234            func_name=inspect.stack()[0].function[5:],
2235        )
2236
2237    def test_mps_backend_concatenate_nhwc(self):
2238        class Concat(torch.nn.Module):
2239            def __init__(self):
2240                super().__init__()
2241                self.conv = torch.nn.Conv2d(
2242                    in_channels=1,
2243                    out_channels=3,
2244                    kernel_size=(3, 3),
2245                    padding=1,
2246                    bias=False,
2247                )
2248
2249            def forward(self, x, y):
2250                x = self.conv(x)
2251                return torch.concatenate((y, x, y, x), 1)
2252
2253        self.lower_and_test_with_partitioner(
2254            Concat(),
2255            (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3)),
2256            func_name=inspect.stack()[0].function[5:],
2257        )
2258
2259    def test_mps_backend_concatenate_nhwc2(self):
2260        class Concat(torch.nn.Module):
2261            def __init__(self):
2262                super().__init__()
2263                self.conv = torch.nn.Conv2d(
2264                    in_channels=1,
2265                    out_channels=3,
2266                    kernel_size=(3, 3),
2267                    padding=1,
2268                    bias=False,
2269                )
2270
2271            def forward(self, x, y):
2272                x = self.conv(x)
2273                y = self.conv(y)
2274                return torch.concatenate((y, x, y, x), 3)
2275
2276        self.lower_and_test_with_partitioner(
2277            Concat(),
2278            (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3)),
2279            func_name=inspect.stack()[0].function[5:],
2280        )
2281
2282    def test_mps_backend_slice_copy(self):
2283        class Slice(torch.nn.Module):
2284            def forward(self, x):
2285                return x[1:3, -2:, :-1]
2286
2287        self.lower_and_test_with_partitioner(
2288            Slice(), (torch.randn(5, 5, 5),), func_name=inspect.stack()[0].function[5:]
2289        )
2290
2291    def test_mps_backend_slice_copy_stride_non_1(self):
2292        class Slice(torch.nn.Module):
2293            def forward(self, x):
2294                return x[:3:-1, 2:, :3]
2295
2296        self.assertRaises(
2297            Exception,
2298            self.lower_and_test_with_partitioner,
2299            Slice(),
2300            (torch.randn(5, 5, 5),),
2301            func_name=inspect.stack()[0].function[5:],
2302        )
2303
2304    def test_mps_backend_slice_copy_dim_0(self):
2305        class Slice(torch.nn.Module):
2306            def forward(self, x):
2307                return x[-1:3, 2:, 3:3]
2308
2309        self.lower_module_and_test_output(
2310            Slice(),
2311            (torch.randn(5, 5, 5),),
2312            use_partitioner=False,
2313            func_name=inspect.stack()[0].function[5:],
2314        )
2315
2316    def test_mps_backend_slice_copy_memory_format(self):
2317        class ConvSlice(torch.nn.Module):
2318            def __init__(self):
2319                super().__init__()
2320                self.conv = torch.nn.Conv2d(
2321                    in_channels=1,
2322                    out_channels=3,
2323                    kernel_size=(3, 3),
2324                    padding=1,
2325                    bias=False,
2326                )
2327
2328            def forward(self, x):
2329                y = self.conv(x)
2330                return y[:, :, 2:3, -2:]
2331
2332        self.lower_and_test_with_partitioner(
2333            ConvSlice(),
2334            (torch.randn(1, 1, 3, 3),),
2335            func_name=inspect.stack()[0].function[5:],
2336        )
2337
2338    def test_mps_backend_bitwise_and(self):
2339        class BitwiseAnd(torch.nn.Module):
2340            def forward(self, x, y):
2341                return torch.bitwise_and(x, y)
2342
2343        model_inputs = (
2344            torch.tensor([-1, -2, 3], dtype=torch.int8),
2345            torch.tensor([1, 0, 3], dtype=torch.int8),
2346        )
2347        self.lower_and_test_with_partitioner(
2348            BitwiseAnd(), model_inputs, func_name=inspect.stack()[0].function[5:]
2349        )
2350
2351    def test_mps_backend_bitwise_or(self):
2352        class BitwiseOr(torch.nn.Module):
2353            def forward(self, x, y):
2354                return torch.bitwise_or(x, y)
2355
2356        model_inputs = (
2357            torch.tensor([-1, -2, 3], dtype=torch.int8),
2358            torch.tensor([1, 0, 3], dtype=torch.int8),
2359        )
2360        self.lower_and_test_with_partitioner(
2361            BitwiseOr(), model_inputs, func_name=inspect.stack()[0].function[5:]
2362        )
2363
2364    def test_mps_backend_bitwise_xor(self):
2365        class BitwiseXor(torch.nn.Module):
2366            def forward(self, x, y):
2367                return torch.bitwise_xor(x, y)
2368
2369        model_inputs = (
2370            torch.tensor([True, True, False]),
2371            torch.tensor([False, True, False]),
2372        )
2373        self.lower_and_test_with_partitioner(
2374            BitwiseXor(), model_inputs, func_name=inspect.stack()[0].function[5:]
2375        )
2376
2377    def test_mps_backend_bitwise_not(self):
2378        class BitwiseNot(torch.nn.Module):
2379            def forward(self, x):
2380                return torch.bitwise_not(x)
2381
2382        model_inputs = (torch.tensor([-1, -2, 3], dtype=torch.int8),)
2383        self.lower_and_test_with_partitioner(
2384            BitwiseNot(), model_inputs, func_name=inspect.stack()[0].function[5:]
2385        )
2386
2387    def test_mps_backend_bitwise_not_with_bool(self):
2388        class BitwiseNot(torch.nn.Module):
2389            def forward(self, x):
2390                return torch.bitwise_not(x)
2391
2392        model_inputs = (torch.tensor([True, True, False]),)
2393        self.lower_and_test_with_partitioner(
2394            BitwiseNot(), model_inputs, func_name=inspect.stack()[0].function[5:]
2395        )
2396
2397    def test_mps_backend_bitwise_with_scalar(self):
2398        class BitwiseScalarModule(torch.nn.Module):
2399            def __init__(self):
2400                super().__init__()
2401                self._scalar = 3
2402
2403            def forward(self, x):
2404                out1 = torch.ops.aten.bitwise_and.Scalar(x, self._scalar)
2405                return out1
2406
2407        model_inputs = (torch.tensor([-1, -2, 3], dtype=torch.int8),)
2408        self.lower_and_test_with_partitioner(
2409            BitwiseScalarModule(),
2410            model_inputs,
2411            func_name=inspect.stack()[0].function[5:],
2412        )
2413
2414    def test_mps_backend_arange(self):
2415        class ArangeModule(torch.nn.Module):
2416            def __init__(self):
2417                super().__init__()
2418                self._begin = 2.5
2419                self._end = 5
2420                self._step = 0.5
2421
2422            def forward(self):
2423                out1 = torch.arange(end=self._end)
2424                out2 = torch.arange(start=self._begin, end=self._end, step=self._step)
2425                return out1 + out2
2426
2427        self.lower_and_test_with_partitioner(
2428            ArangeModule(), (), func_name=inspect.stack()[0].function[5:]
2429        )
2430
2431    def test_mps_backend_where(self):
2432        class Where(torch.nn.Module):
2433            def forward(self, cond, x, y):
2434                return torch.where(cond, x, y)
2435
2436        x = torch.randn(3, 2)
2437        y = torch.ones(3, 2)
2438        cond = x > 0
2439        module_inputs = (cond, x, y)
2440        self.lower_and_test_with_partitioner(
2441            Where(), module_inputs, func_name=inspect.stack()[0].function[5:]
2442        )
2443
2444    def test_mps_backend_scalar_tensor(self):
2445        class ScalarTensorModule(torch.nn.Module):
2446            def __init__(self):
2447                super().__init__()
2448                self._scalar = 3.0
2449                self._bool = True
2450
2451            def forward(self):
2452                out1 = torch.ops.aten.scalar_tensor(self._scalar)
2453                out2 = torch.ops.aten.scalar_tensor(self._scalar, dtype=torch.int32)
2454                # issue 121117206
2455                out3 = torch.ops.aten.scalar_tensor(self._bool, dtype=torch.bool)
2456                return out1 + out2 + out3
2457
2458        self.lower_and_test_with_partitioner(
2459            ScalarTensorModule(), (), func_name=inspect.stack()[0].function[5:]
2460        )
2461
2462    def test_mps_backend_tril(self):
2463        class TrilModule(torch.nn.Module):
2464            def __init__(self):
2465                super().__init__()
2466                self._k = 1
2467                self._negK = -1
2468
2469            def forward(self, x):
2470                out1 = torch.tril(x, diagonal=self._k)
2471                out2 = torch.tril(x, diagonal=self._negK)
2472                return out1 + out2
2473
2474        model_inputs = (torch.randn(4, 6),)
2475        self.lower_and_test_with_partitioner(
2476            TrilModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2477        )
2478
2479    def test_mps_backend_embedding(self):
2480        class EmbeddingModule(torch.nn.Module):
2481            def __init__(self):
2482                super().__init__()
2483                self._embedding = torch.nn.Embedding(10, 3)
2484                self._embedding_with_padding = torch.nn.Embedding(10, 3, padding_idx=2)
2485
2486            def forward(self, x):
2487                return self._embedding(x) + self._embedding_with_padding(x)
2488
2489        model_inputs = (torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]),)
2490        self.lower_and_test_with_partitioner(
2491            EmbeddingModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
2492        )
2493