#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

import inspect
import logging
import random
import unittest
from enum import Enum

import torch
from executorch.backends.apple.mps.test.test_mps_models import MPS_MODEL_NAME_TO_MODEL
from executorch.backends.apple.mps.test.test_mps_utils import (
    OpSequencesAddConv2d,
    randomize_bn,
    TestMPS,
)
from executorch.examples.models import MODEL_NAME_TO_MODEL
from executorch.examples.models.model_factory import EagerModelFactory

from executorch.exir.tests.models import (
    BasicSinMax,
    CompositeDelegateModule,
    ElementwiseAdd,
    Emformer,
    MLP,
    ModelWithUnusedArg,
    Mul,
    Repeat,
)


class MODEL_TYPE(Enum):
    EXIR_DEFAULT_MODEL = 0
    EXIR_TEST_MODEL = 1
    MPS_TEST_MODEL = 2


EXIR_MODEL_NAME_TO_MODEL = {
    "repeat": lambda: (Repeat(), Repeat().get_random_inputs()),
    "model_with_unused_arg": lambda: (
        ModelWithUnusedArg(),
        ModelWithUnusedArg().get_random_inputs(),
    ),
    "mlp": lambda: (MLP(), MLP().get_random_inputs()),
    "mul_2": lambda: (Mul(), Mul().get_random_inputs()),
    "element_wise_add": lambda: (
        ElementwiseAdd(),
        ElementwiseAdd().get_random_inputs(),
    ),
    "basic_sin_max": lambda: (BasicSinMax(), BasicSinMax().get_random_inputs()),
    "composite_delegate_module": lambda: (
        CompositeDelegateModule(),
        CompositeDelegateModule().get_random_inputs(),
    ),
    "emformer": lambda: (Emformer(), Emformer().get_random_inputs()),
}


def run_model(
    model: str,
    model_type: MODEL_TYPE = MODEL_TYPE.EXIR_DEFAULT_MODEL,
    use_fp16: bool = False,
    lowering_func=None,
):
    logging.info(f"Step 1: Retrieving model: {model}...")
    if model_type == MODEL_TYPE.EXIR_DEFAULT_MODEL:
        m, m_inputs = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL[model])
    elif model_type == MODEL_TYPE.EXIR_TEST_MODEL:
        m, m_inputs = EXIR_MODEL_NAME_TO_MODEL.get(model)()
    elif model_type == MODEL_TYPE.MPS_TEST_MODEL:
        m, m_inputs = MPS_MODEL_NAME_TO_MODEL.get(model)()

    lowering_func(m, m_inputs, model)
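

# Every test below derives its model name with `inspect.stack()[0].function[5:]`:
# `inspect.stack()[0].function` is the name of the currently executing test
# method (e.g. "test_mlp"), and `[5:]` strips the leading "test_" so the
# remainder doubles as the registry key and/or the `func_name` label.
# A hypothetical helper making the idiom explicit (illustration only; the
# tests inline the expression rather than calling this):
def _current_test_model_name(depth: int = 1) -> str:
    # depth=1 is the caller's frame, i.e. the test method itself.
    return inspect.stack()[depth].function[len("test_") :]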


class TestMPSBackendExirModels(TestMPS):
    def test_model_with_unused_arg(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.EXIR_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_mlp(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.EXIR_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_mul_2(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.EXIR_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_element_wise_add(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.EXIR_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_emformer(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.EXIR_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )


class TestMPSBackendMPSModels(TestMPS):
    def test_conv2D(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.MPS_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_norm(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.MPS_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_module_add(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.MPS_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_toy_model_for_mem_planning(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.MPS_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_mem_planning_with_scratch_tensor(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.MPS_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_module_ops_return_tensor_list(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.MPS_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_module_contiguous_tensor(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.MPS_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )

    def test_module_input_dynamic_shape(self):
        run_model(
            inspect.stack()[0].function[5:],
            MODEL_TYPE.MPS_TEST_MODEL,
            lowering_func=self.lower_and_test_with_partitioner,
        )


class TestMPSUnitOpTesting(TestMPS):
    def test_mps_backend_split_copy(self):
        class SplitCopy(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.split(x, 2, 1)

        example_inputs = (torch.randn(3, 5, 4, 7),)
        self.lower_and_test_with_partitioner(
            SplitCopy(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_unbind_copy(self):
        class UnbindCopy(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.unbind(x, 1)

        example_inputs = (torch.randn(3, 5, 4, 7),)
        self.lower_and_test_with_partitioner(
            UnbindCopy(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_pixel_shuffle(self):
        class PixelShuffle(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.pixel_shuffle(x, 2)

        example_inputs = (torch.randn(3, 8, 4, 7),)
        self.lower_and_test_with_partitioner(
            PixelShuffle(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_cumsum(self):
        class CumulativeSum(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, *x):
                return torch.cumsum(x[0], dim=0)

        example_inputs = (torch.randn(3, 5, 4, 7),)
        self.lower_and_test_with_partitioner(
            CumulativeSum(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_stack(self):
        class Stack(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, *x):
                return torch.stack((x), 0)

        example_inputs = (
            torch.randn(1, 5, 1, 8),
            torch.randn(1, 5, 1, 8),
        )
        self.lower_and_test_with_partitioner(
            Stack(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_cat(self):
        class Cat(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, *x):
                return torch.cat((x), 1)

        example_inputs = (
            torch.randn(1, 5, 1, 8),
            torch.randn(1, 5, 1, 8),
        )
        self.lower_and_test_with_partitioner(
            Cat(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_expand_copy(self):
        class ExpandCopy(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.example_inputs = [7, 5, 4, 8]

            def forward(self, x):
                return x.expand(self.example_inputs)

        example_inputs = (torch.randn(1, 5, 1, 8),)
        self.lower_and_test_with_partitioner(
            ExpandCopy(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_select(self):
        class Select(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.select(x, 3, 2)

        example_inputs = (torch.randn(3, 5, 4, 7),)
        self.lower_and_test_with_partitioner(
            Select(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_view_copy(self):
        class ViewCopy(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.example_inputs = [2, 10, 2, 4]

            def forward(self, x):
                return x.view(self.example_inputs)

        example_inputs = (torch.randn(1, 5, 4, 8),)
        self.lower_and_test_with_partitioner(
            ViewCopy(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_mean_dim_2(self):
        class Mean(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.mean(x, (-1, -2), keepdim=True)

        example_inputs = (torch.randn(1, 5, 4, 4),)
        self.lower_and_test_with_partitioner(
            Mean(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_squeeze_dim_1(self):
        class Squeeze(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                y = torch.squeeze(x, 2)
                return torch.squeeze(y, 0)

        example_inputs = (torch.randn(1, 5, 1, 1, 4),)
        self.lower_and_test_with_partitioner(
            Squeeze(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_unsqueeze_dim_1(self):
        class Squeeze(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.unsqueeze(x, 1)

        example_inputs = (torch.randn(1, 5, 1, 4),)
        self.lower_and_test_with_partitioner(
            Squeeze(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_mean_dim_no_keepdim(self):
        class Mean(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.mean(x, (-1, -2), keepdim=False)

        example_inputs = (torch.randn(1, 5, 4, 4),)
        self.lower_and_test_with_partitioner(
            Mean(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_mean_dim_unsupported(self):
        class Mean(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.mean(x, (3), keepdim=True)

        example_inputs = (torch.randn(1, 5, 4, 4),)
        self.lower_and_test_with_partitioner(
            Mean(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_static_transpose(self):
        class PermuteModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.nchw_to_nhwc = [0, 2, 3, 1]

            def forward(self, x):
                return torch.permute(x, self.nchw_to_nhwc)

        example_inputs = (torch.randn(1, 1, 4, 4),)
        self.lower_module_and_test_output(
            PermuteModule(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_sequential_conv2d(self):
        class TwoConv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.first = torch.nn.Conv2d(
                    in_channels=1,
                    out_channels=3,
                    kernel_size=(3, 3),
                    padding=1,
                    bias=False,
                )
                self.second = torch.nn.Conv2d(
                    in_channels=3,
                    out_channels=2,
                    kernel_size=(3, 3),
                    padding=1,
                    bias=False,
                )

            def forward(self, x):
                return self.second(self.first(x))

        example_inputs = (torch.randn(1, 1, 3, 3),)
        self.lower_and_test_with_partitioner(
            TwoConv(), example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_conv2d_bn_1(self):
        class ModelConvBN(torch.nn.Module):
            def __init__(self, in_features: int, out_features: int, kernel_size):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
                self.bn = randomize_bn(out_features)

            def forward(self, x):
                y = self.conv2d(x)
                y = self.bn(y)
                return y

        model = ModelConvBN(2, 2, (2, 2)).eval()

        self.lower_and_test_with_partitioner(
            model, (torch.randn(2, 2, 4, 4),), func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_conv2d(self):
        groups = 1
        stride = [2, 2]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = 2
        out_channels = 1
        width = 8
        height = 8
        batches = 1
        example_inputs = (torch.randn(batches, in_channels, height, width),)
        conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=True,
        )
        conv.eval()
        self.lower_and_test_with_partitioner(
            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_conv1d(self):
        example_inputs = (torch.randn(1, 57, 40),)
        stride = random.randint(1, 4)
        padding = random.randint(1, 4)
        conv = torch.nn.Conv1d(
            57,
            20,
            stride=stride,
            padding=padding,
            kernel_size=3,
            bias=random.choice([True, False]),
        )
        conv.eval()
        self.lower_and_test_with_partitioner(
            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_conv2d_simple(self):
        N = 10
        C = 10
        H = 4
        W = 6
        groups = 2
        input_memory_format = torch.contiguous_format
        weight_memory_format = torch.contiguous_format
        strideX = random.randint(1, 4)
        strideY = random.randint(1, 4)
        example_inputs = (
            torch.randn(N, C, H, W).to(memory_format=input_memory_format),
        )
        conv = torch.nn.Conv2d(
            in_channels=N,
            out_channels=C,
            kernel_size=H,
            groups=groups,
            stride=(strideX, strideY),
            bias=False,
        )
        conv.weight.data = conv.weight.to(memory_format=weight_memory_format)
        conv.eval()
        self.lower_and_test_with_partitioner(
            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_conv2d_to_depthwise_conv_3d(self):
        N = 10
        C = 10
        H = 4
        W = 6
        groups = 10
        input_memory_format = torch.contiguous_format
        weight_memory_format = torch.contiguous_format
        strideX = random.randint(1, 4)
        strideY = random.randint(1, 4)
        example_inputs = (
            torch.randn(N, C, H, W).to(memory_format=input_memory_format),
        )
        conv = torch.nn.Conv2d(
            in_channels=N,
            out_channels=C,
            kernel_size=H,
            groups=groups,
            stride=(strideX, strideY),
        )
        conv.weight.data = conv.weight.to(memory_format=weight_memory_format)
        conv.eval()
        self.lower_and_test_with_partitioner(
            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_conv2d_single_int_params(self):
        groups = 1
        stride = 2
        padding = "valid"
        dilation = 1
        in_channels = 2
        out_channels = 1
        width = 8
        height = 8
        batches = 1
        example_inputs = (torch.randn(batches, in_channels, height, width),)
        conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=True,
        )
        conv.eval()
        self.lower_and_test_with_partitioner(
            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_conv2d_dw(self):
        # Depthwise Convolution Requirements:
        # - Groups must equal In Channels
        # - Out Channels must be a positive multiple of In Channels
        groups = 2
        stride = [2, 2]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = groups
        out_channels = 3 * in_channels
        width = 8
        height = 8
        batches = 1
        example_inputs = (torch.randn(batches, in_channels, height, width),)
        conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=True,
        )
        conv.eval()
        self.lower_and_test_with_partitioner(
            conv, example_inputs, func_name=inspect.stack()[0].function[5:]
        )
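
    # For the depthwise case above: with groups == in_channels == 2 and
    # out_channels == 6 (a 3x multiple), each input channel is convolved
    # independently. PyTorch stores the Conv2d weight as
    # (out_channels, in_channels // groups, kH, kW) = (6, 1, 3, 3) here.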

    def test_mps_backend_mm(self):
        in_sizes = [1, 4, 4]
        input_sizes = [4, 37, 17]
        output_sizes = [4, 17, 37]
        for i, _ in enumerate(in_sizes):
            in_size = int(in_sizes[i])
            input_size = int(input_sizes[i])
            output_size = int(output_sizes[i])
            linear = torch.nn.Linear(input_size, output_size, bias=False).eval()
            example_input = (torch.randn(in_size, input_size),)

            self.lower_and_test_with_partitioner(
                linear, example_input, func_name=inspect.stack()[0].function[5:]
            )

    def test_mps_backend_bmm(self):
        class BmmModule(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.bmm = torch.bmm

            def forward(self, x, y):
                return self.bmm(x, y)

        mul_module = BmmModule()
        model_inputs = (
            torch.randn((3, 1, 8)),
            torch.randn((3, 8, 1)),
        )

        self.lower_and_test_with_partitioner(
            mul_module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_addmm(self):
        in_sizes = [1, 4, 4]
        input_sizes = [4, 37, 17]
        output_sizes = [4, 17, 37]
        for i, _ in enumerate(in_sizes):
            in_size = int(in_sizes[i])
            input_size = int(input_sizes[i])
            output_size = int(output_sizes[i])
            linear = torch.nn.Linear(input_size, output_size, bias=True).eval()
            example_input = (torch.randn(in_size, input_size),)

            self.lower_and_test_with_partitioner(
                linear, example_input, func_name=inspect.stack()[0].function[5:]
            )
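
    # The two Linear sweeps above differ only in `bias`: without a bias the
    # export presumably decomposes Linear into a plain matmul (hence "mm"),
    # while the biased variant maps to addmm (bias + mat1 @ mat2). The shared
    # shape lists sweep batch and feature sizes together.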

    def test_mps_backend_full_ones_default(self):
        class Ones(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self):
                size = (4, 37, 17)
                return torch.ones(size)

        self.lower_and_test_with_partitioner(
            Ones(), (), func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_full_zeros_default(self):
        class Zeros(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self):
                size = (4, 37, 17)
                return torch.zeros(size=size)

        self.lower_and_test_with_partitioner(
            Zeros(), (), func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_full_default(self):
        class Full(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self):
                size = (4, 37, 17)
                return torch.full(size=size, fill_value=2.0)

        self.lower_and_test_with_partitioner(
            Full(), (), func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_full_like(self):
        class Full_Like(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.full_like(x, fill_value=2.0)

        const_module = Full_Like()
        model_inputs = (torch.randn(4, 37, 17),)

        self.lower_and_test_with_partitioner(
            const_module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_logit_1(self):
        class LogitModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                z = torch.ops.aten.logit.default(x)
                return z

        logit_module = LogitModule()
        model_inputs = (torch.rand(5),)

        self.lower_and_test_with_partitioner(
            logit_module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_logit_2(self):
        class LogitModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                z = torch.ops.aten.logit.default(x, eps=1e-6)
                return z

        logit_module = LogitModule()
        model_inputs = (torch.rand(5),)

        self.lower_and_test_with_partitioner(
            logit_module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_round(self):
        class RoundModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out = torch.round(x)
                return out

        module = RoundModule()
        model_inputs = (torch.randn(5, 2),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_amax(self):
        class AmaxModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out = torch.amax(x, 1)
                return out

        module = AmaxModule()
        model_inputs = (torch.randn(2, 3, 4),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_amin(self):
        class AminModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out = torch.amin(x, 1)
                return out

        module = AminModule()
        model_inputs = (torch.randn(2, 3, 4),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    @unittest.skip
    def test_mps_backend_min_dim(self):
        class MinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out = torch.min(x, 1)
                return out

        module = MinModule()
        model_inputs = (torch.randn(2, 3, 4),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )
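
    # torch.min(x, dim) above returns a (values, indices) pair rather than a
    # single tensor, which is presumably why the test is skipped: the
    # reduction-with-indices variant differs from the values-only amin/amax
    # overloads exercised just before it.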

    def test_mps_backend_argmax_1(self):
        class ArgmaxModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out1 = torch.argmax(x, 1)
                return out1

        module = ArgmaxModule()
        model_inputs = (torch.randn(5, 10),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_argmax_2(self):
        class ArgmaxModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out1 = torch.argmax(x)
                return out1

        module = ArgmaxModule()
        model_inputs = (torch.randn(5, 10),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_argmin_1(self):
        class ArgminModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out1 = torch.argmin(x, 1)
                return out1

        module = ArgminModule()
        model_inputs = (torch.randn(5, 10),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_argmin_2(self):
        class ArgminModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out1 = torch.argmin(x)
                return out1

        module = ArgminModule()
        model_inputs = (torch.randn(5, 10),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_minimum(self):
        class MinimumModule(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.minimum_module = torch.minimum

            def forward(self, x, y):
                return self.minimum_module(x, y)

        module = MinimumModule()
        model_inputs = (
            torch.randn(1, 3, 6),
            torch.randn(1, 3, 6),
        )
        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_eq_tensor_1(self):
        class EqModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.eq(x, y)
                return out

        module = EqModule()
        model_inputs = (
            torch.randn(2, 3, 4),
            torch.randn(2, 3, 4),
        )

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_eq_tensor_2(self):
        class EqModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.eq(x, y)
                return out

        module = EqModule()
        input_tensor = torch.randn(2, 3, 4)
        model_inputs = (input_tensor, input_tensor)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_eq_scalar(self):
        class EqModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out = torch.eq(x, 1.0)
                return out

        module = EqModule()
        model_inputs = (torch.randn(2, 3, 4),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_ne_tensor_1(self):
        class NeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.ne(x, y)
                return out

        module = NeModule()
        model_inputs = (
            torch.randn(2, 3, 4),
            torch.randn(2, 3, 4),
        )

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_ne_tensor_2(self):
        class NeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.ne(x, y)
                return out

        module = NeModule()
        input_tensor = torch.randn(2, 3, 4)
        model_inputs = (input_tensor, input_tensor)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_ne_scalar(self):
        class NeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out = torch.ne(x, 1.0)
                return out

        module = NeModule()
        model_inputs = (torch.randn(2, 3, 4),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_ge_tensor_1(self):
        class GeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.ge(x, y)
                return out

        module = GeModule()
        model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_ge_tensor_2(self):
        class GeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.ge(x, y)
                return out

        module = GeModule()

        input_tensor = torch.randn(2, 3, 4)
        model_inputs = (input_tensor, input_tensor)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_ge_scalar(self):
        class GeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out = torch.ge(x, 1.0)
                return out

        module = GeModule()
        model_inputs = (torch.randn(2, 3, 4),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_gt_tensor_1(self):
        class GtModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.gt(x, y)
                return out

        module = GtModule()
        model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_gt_tensor_2(self):
        class GtModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.gt(x, y)
                return out

        module = GtModule()
        input_tensor = torch.randn(2, 3, 4)
        model_inputs = (input_tensor, input_tensor)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_gt_scalar(self):
        class GtModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out = torch.gt(x, 1.0)
                return out

        module = GtModule()
        model_inputs = (torch.randn(2, 3, 4),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )
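
    # In each `_tensor_2` variant above, the same tensor object is passed for
    # both inputs, so the comparison sees two aliases of one buffer and the
    # expected result is uniform: all-True for eq/ge, all-False for ne/gt.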

    def test_mps_backend_isnan(self):
        class IsNanModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.isnan(x)

        module = IsNanModule()
        model_inputs = (
            torch.randn(8, 3, 4, 5).index_put_(
                indices=[torch.tensor([random.randrange(0, 8)])],
                values=torch.tensor(float("nan")),
            ),
        )
        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_partitioner(self):
        # `index.Tensor` is not yet natively supported, so the MPSPartitioner
        # leaves it outside the delegate as a fallback. Once it is
        # implemented, replace the op here with a still-unsupported one.
        class IndexTensorModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.indices = torch.tensor([0, 5, 2, 3])

            def forward(self, x):
                y = torch.add(x, 2.0)
                z = y[self.indices]
                r = z + x[self.indices]
                d = r - 2
                p = torch.pow(d, 4)
                return p / 10

        module = IndexTensorModule()

        model_inputs = (torch.randn(8, 3, 4, 5),)
        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_1(self):
        class IndexGet(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x[[0, 1, 2], [0, 1, 0]]

        module = IndexGet()
        model_inputs = (torch.tensor([[1, 2], [3, 4], [5, 6]]),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_2(self):
        class IndexGet(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x[:, [0, 4, 2]]

        module = IndexGet()
        model_inputs = (torch.randn(5, 7, 3),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_3(self):
        class IndexGet(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x[:, [[0, 1], [4, 3]]]

        module = IndexGet()
        model_inputs = (torch.randn(5, 7, 3),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_4(self):
        class IndexGet(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x[[0, 4, 2]]

        module = IndexGet()
        model_inputs = (torch.randn(5, 7, 3),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indexing_get_5(self):
        class IndexGet(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x[[0, 2, 1], :, 0]

        module = IndexGet()
        model_inputs = (torch.ones(3, 2, 4),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_indices2d(self):
        class IndexGet(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, rows, columns):
                return x[rows, columns]

        module = IndexGet()
        x = torch.arange(0, 12).resize(4, 3)
        rows = torch.tensor([[0, 0], [3, 3]])
        columns = torch.tensor([[0, 2], [0, 2]])
        model_inputs = (
            x,
            rows,
            columns,
        )

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_slicing_using_advanced_index_for_column_0(self):
        class IndexGet(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x[1:4]

        module = IndexGet()
        model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_slicing_using_advanced_index_for_column_1(self):
        class IndexGet(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                # using advanced index for column
                return x[1:4, [1, 2]]

        module = IndexGet()
        model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    @unittest.skip
    def test_boolean_array_indexing(self):
        class IndexGet(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x[x > 5]

        module = IndexGet()
        model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )
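
    # Boolean mask indexing (x[x > 5]) produces a data-dependent output
    # shape, which is presumably why the test above remains skipped: the
    # number of True elements is only known at runtime.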

    def test_mps_backend_isinf(self):
        class IsInfModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.isinf(x)

        module = IsInfModule()
        model_inputs = (
            torch.randn(8, 3, 4, 5).index_put_(
                indices=[torch.tensor([random.randrange(0, 8)])],
                values=torch.tensor(float("inf")),
            ),
        )
        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_le_tensor_1(self):
        class LeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.le(x, y)
                return out

        module = LeModule()
        model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_le_tensor_2(self):
        class LeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.le(x, y)
                return out

        module = LeModule()
        input_tensor = torch.randn(2, 3, 4)
        model_inputs = (input_tensor, input_tensor)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_le_scalar(self):
        class LeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out = torch.le(x, 1.0)
                return out

        module = LeModule()
        model_inputs = (torch.randn(2, 3, 4),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_lt_tensor_1(self):
        class LtModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.lt(x, y)
                return out

        module = LtModule()
        model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_lt_tensor_2(self):
        class LtModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                out = torch.lt(x, y)
                return out

        module = LtModule()
        input_tensor = torch.randn(2, 3, 4)
        model_inputs = (input_tensor, input_tensor)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_lt_scalar(self):
        class LtModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                out = torch.lt(x, 1.0)
                return out

        module = LtModule()
        model_inputs = (torch.randn(2, 3, 4),)

        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    @torch.inference_mode()  # TODO Use for capturing.
    def test_mps_backend_linear(self):
        in_size = 2
        input_size = 3
        output_size = 4
        linear = torch.nn.Linear(input_size, output_size).eval()
        example_input = (torch.randn(in_size, input_size),)

        self.lower_and_test_with_partitioner(
            linear, example_input, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_glu(self):
        class GLUModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.glu = torch.nn.GLU(dim=dim)

            def forward(self, x):
                return self.glu(x)

        shape = (4, 2)
        for dim in list(range(len(shape))) + [-1]:
            model_inputs = (torch.rand(shape),)
            glu_module = GLUModule(dim)
            self.lower_and_test_with_partitioner(
                glu_module, model_inputs, func_name=inspect.stack()[0].function[5:]
            )
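
    # GLU splits its input in half along `dim` and computes a * sigmoid(b),
    # so every dim tested above must have even length; shape (4, 2) is even
    # along dims 0, 1, and -1 (-1 aliasing dim 1).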

    def test_mps_backend_softmax(self):
        class SoftMaxModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.softmax = torch.nn.Softmax(dim=dim)

            def forward(self, x):
                return self.softmax(x)

        shape = (3, 5, 7)
        for dim in list(range(len(shape))) + [-1]:
            model_inputs = (torch.rand(shape),)
            softmax_module = SoftMaxModule(dim)
            self.lower_and_test_with_partitioner(
                softmax_module, model_inputs, func_name=inspect.stack()[0].function[5:]
            )

    def test_mps_backend_log_softmax(self):
        class LogSoftMaxModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.logsoftmax = torch.nn.LogSoftmax(dim=dim)

            def forward(self, x):
                return self.logsoftmax(x)

        shape = (3, 5, 7)
        for dim in list(range(len(shape))) + [-1]:
            model_inputs = (torch.rand(shape),)
            logsoftmax_module = LogSoftMaxModule(dim)

            self.lower_and_test_with_partitioner(
                logsoftmax_module,
                model_inputs,
                func_name=inspect.stack()[0].function[5:],
            )

    def test_mps_backend_hardtanh(self):
        class HardTanhModule(torch.nn.Module):
            def __init__(self, min_val=-1.0, max_val=1.0):
                super().__init__()
                self.hardtanh = torch.nn.Hardtanh(min_val, max_val)

            def forward(self, x):
                return self.hardtanh(x)

        inputs = [torch.randn(2, 3, 4), torch.randn(7, 5, 2), torch.randn(2, 9)]
        for test_input in inputs:
            hardtanh_model = HardTanhModule()
            self.lower_and_test_with_partitioner(
                hardtanh_model, (test_input,), func_name=inspect.stack()[0].function[5:]
            )

        for test_input in inputs:
            hardtanh_model = HardTanhModule(-2, 2)
            self.lower_and_test_with_partitioner(
                hardtanh_model, (test_input,), func_name=inspect.stack()[0].function[5:]
            )

    def test_mps_backend_Relu(self):
        class ReluModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        example_input = torch.randn(2, 3, 4)
        self.lower_and_test_with_partitioner(
            ReluModule(), (example_input,), func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_GELU(self):
        class GELUModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gelu = torch.nn.GELU()
                self.gelu_tanh = torch.nn.GELU(approximate="tanh")

            def forward(self, x):
                return self.gelu(x)
                # MPS TODO: MPS Gelu tanh fails
                # return self.gelu_tanh(y)

        example_input = torch.randn(2, 3, 4)
        self.lower_and_test_with_partitioner(
            GELUModule(), (example_input,), func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_leaky_Relu(self):
        class LeakyReluModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.leaky_relu = torch.nn.LeakyReLU()
                self.leaky_relu_2 = torch.nn.LeakyReLU(1.0)

            def forward(self, x):
                out = self.leaky_relu(x)
                out = self.leaky_relu_2(out)
                return out

        example_input = torch.randn(2, 3, 4)
        self.lower_and_test_with_partitioner(
            LeakyReluModule(),
            (example_input,),
            func_name=inspect.stack()[0].function[5:],
        )

    def test_mps_backend_sigmoid(self):
        class SigmoidModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x):
                return self.sigmoid(x)

        model_inputs = (torch.rand(7, 5, 3),)
        sigmoid_module = SigmoidModule()
        self.lower_and_test_with_partitioner(
            sigmoid_module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_constant_pad_nd(self):
        class PadModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.constant_pad = torch.nn.ConstantPad2d((1, 2), 0)

            def forward(self, x):
                return self.constant_pad(x)

        model_inputs = (torch.rand(1, 2, 3, 4),)
        pad_module = PadModule()
        self.lower_and_test_with_partitioner(
            pad_module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_index_select(self):
        class IndexSelectModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input, index):
                return torch.index_select(input, dim=2, index=index)

        model_inputs = (torch.rand(2, 8, 4, 5), torch.tensor([3, 0, 1]))
        module = IndexSelectModule()
        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_empty(self):
        class EmptyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self):
                return torch.empty((3, 4, 5), dtype=torch.float32)

        self.lower_and_test_with_partitioner(
            EmptyModule(), (), func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_static_constant_pad(self):
        class StaticConstantPadModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y, z):
                pad_6 = (1, 2, 3, 4, 5, 6)
                pad_4 = (1, 2, 3, 4)
                pad_2 = (1, 2)
                a = torch.nn.functional.pad(
                    input=x,
                    pad=pad_6,
                    mode="constant",
                    value=2.3,
                )
                b = torch.nn.functional.pad(
                    input=x,
                    pad=pad_4,
                    mode="constant",
                    value=1.3,
                )
                c = torch.nn.functional.pad(
                    input=x,
                    pad=pad_2,
                    mode="constant",
                    value=2.1,
                )
                d = torch.nn.functional.pad(
                    input=y,
                    pad=pad_6,
                    mode="constant",
                    value=2.7,
                )
                e = torch.nn.functional.pad(
                    input=y,
                    pad=pad_4,
                    mode="constant",
                    value=1.9,
                )
                f = torch.nn.functional.pad(
                    input=y,
                    pad=pad_2,
                    mode="constant",
                    value=3.1,
                )
                g = torch.nn.functional.pad(
                    input=z,
                    pad=pad_4,
                    mode="constant",
                    value=2.9,
                )
                h = torch.nn.functional.pad(
                    input=z,
                    pad=pad_2,
                    mode="constant",
                    value=1.2,
                )
                return (a, b, c, d, e, f, g, h)

        example_inputs = (
            torch.randn(size=(5, 4, 3, 2)),
            torch.randn(size=(5, 3, 2)),
            torch.randn(size=(4, 3)),
        )
        self.lower_and_test_with_partitioner(
            StaticConstantPadModule(),
            example_inputs,
            func_name=inspect.stack()[0].function[5:],
        )
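
    # torch.nn.functional.pad consumes the pad tuple in (before, after) pairs
    # starting from the last dimension: pad_2 pads only the last dim, pad_4
    # the last two, and pad_6 the last three, which is why pad_6 is applied
    # only to the 4-D and 3-D inputs above and never to the 2-D `z`.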

    def test_mps_clamp_min_max(self):
        class Clamp(torch.nn.Module):
            def __init__(self, min_val, max_val):
                super().__init__()
                self.clamp = torch.clamp
                self.min_val = min_val
                self.max_val = max_val

            def forward(self, *x):
                out1 = self.clamp(x[0], min=-0.5, max=0.5)
                out2 = self.clamp(x[0], min=-5, max=5)
                return out1, out2

        model_inputs = (
            torch.randn(1, 4, 122, 122) * 2,
            torch.randint(-100, 100, (1, 4, 15, 20)),
        )
        module = Clamp(-0.5, 0.5)
        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_clamp_min(self):
        class Clamp(torch.nn.Module):
            def __init__(self, min_val, max_val):
                super().__init__()
                self.clamp = torch.clamp
                self.min_val = min_val
                self.max_val = max_val

            def forward(self, x):
                return self.clamp(x, min=self.min_val, max=self.max_val)

        model_inputs = (torch.randn(1, 4, 122, 122) * 2,)
        module = Clamp(-0.5, None)
        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_clamp_max(self):
        class Clamp(torch.nn.Module):
            def __init__(self, min_val, max_val):
                super().__init__()
                self.clamp = torch.clamp
                self.min_val = min_val
                self.max_val = max_val

            def forward(self, x):
                return self.clamp(x, min=self.min_val, max=self.max_val)

        model_inputs = (torch.randn(1, 4, 122, 122) * 2,)
        module = Clamp(None, 0.5)
        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_maxpool2d_default(self):
        class MaxPool2dModule(torch.nn.Module):
            def __init__(
                self,
                kernel_size=3,
                stride=1,
                padding=3,
                dilation=1,
            ):
                super().__init__()
                self.max_pool2d_module = torch.nn.MaxPool2d(
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                )

            def forward(self, x):
                return self.max_pool2d_module(x)

        maxpool2d_module = MaxPool2dModule(3, 1, 0, 1)
        model_inputs = (torch.randn(4, 3, 24, 24),)

        self.lower_and_test_with_partitioner(
            maxpool2d_module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_maxpool2d_unsupported(self):
        class MaxPool2dModule(torch.nn.Module):
            def __init__(
                self,
                kernel_size=3,
                stride=1,
                padding=0,
                dilation=1,
            ):
                super().__init__()
                self.max_pool2d_module = torch.nn.MaxPool2d(
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    return_indices=True,
                )

            def forward(self, x):
                return self.max_pool2d_module(x)[1]

        maxpool2d_module = MaxPool2dModule(3, 1, 0, 1)
        model_inputs = (torch.randn(4, 3, 24, 24),)

        self.lower_and_test_with_partitioner(
            maxpool2d_module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )
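
    # With return_indices=True, max_pool2d lowers to the variant that also
    # produces an index tensor; returning output [1] (the indices) above is
    # what the test name flags as unsupported, so the partitioner is expected
    # to keep that op off the MPS delegate.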

    def test_mps_backend_max_dim_vals(self):
        class MaxModule(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x):
                max_vals, _ = torch.max(x, dim=3, keepdim=True)
                return max_vals

        model_inputs = (torch.randn(16, 3, 12, 12),)
        max_dim_module = MaxModule()

        self.lower_and_test_with_partitioner(
            max_dim_module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_max_dim(self):
        class MaxModule(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x):
                x = torch.add(x, x)
                max_values_1, max_indices_1 = torch.max(x, dim=2, keepdim=True)
                max_values_2, max_indices_2 = torch.max(x, dim=3, keepdim=True)
                return (max_values_1, max_indices_1, max_values_2, max_indices_2)

        model_inputs = (torch.randn(16, 3, 12, 12),)
        max_dim_module = MaxModule()

        self.lower_and_test_with_partitioner(
            max_dim_module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_multiply(self):
        class MulModule(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.mul = torch.mul

            def forward(self, x, y):
                return self.mul(x, y)

        mul_module = MulModule()
        model_inputs = (
            torch.randn((1, 8)),
            torch.randn((8, 1)),
        )

        self.lower_and_test_with_partitioner(
            mul_module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_sub(self):
        class Sub(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = torch.sub

            def forward(self, x, y):
                return self.sub(x, y)

        module = Sub()
        M = torch.randn(2, 3)
        N = torch.randn(2, 3)
        model_inputs = (
            M,
            N,
        )
        self.lower_and_test_with_partitioner(
            module, model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_clone(self):
        class Clone(torch.nn.Module):
            def forward(self, x):
                return torch.clone(x)

        model_inputs = (torch.randn(1, 3, 3),)
        self.lower_and_test_with_partitioner(
            Clone(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_floor(self):
        class Floor(torch.nn.Module):
            def forward(self, x):
                return torch.floor(x)

        model_inputs = (torch.randn(1, 3, 3),)
        self.lower_and_test_with_partitioner(
            Floor(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_sqrt(self):
        class Sqrt(torch.nn.Module):
            def forward(self, x):
                return torch.sqrt(x)

        model_inputs = (torch.randn(1, 3, 3).abs(),)
        self.lower_and_test_with_partitioner(
            Sqrt(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_ceil(self):
        class Ceil(torch.nn.Module):
            def forward(self, x):
                return torch.ceil(x)

        model_inputs = (torch.randn(1, 3, 3),)
        self.lower_and_test_with_partitioner(
            Ceil(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_hardswish(self):
        model_inputs = (torch.randn(1, 3, 3),)

        class HardswishModule(torch.nn.Module):
            def __init__(self):
                super(HardswishModule, self).__init__()
                self.hardswish_out_of_place = torch.nn.Hardswish()
                self.hardswish_in_place = torch.nn.Hardswish(inplace=True)
                self.hardswish_functional = torch.nn.functional.hardswish

            def forward(self, x):
                a = self.hardswish_out_of_place(x)
                a = self.hardswish_in_place(a)
                a = self.hardswish_functional(a)
                return a

        # TODO(T158969708)
        self.lower_and_test_with_partitioner(
            HardswishModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_leaky_relu(self):
        model_inputs = (torch.randn(1, 3, 3),)

        class LeakyReLUModule(torch.nn.Module):
            def __init__(self):
                super(LeakyReLUModule, self).__init__()
                self.leaky_relu_out_of_place = torch.nn.LeakyReLU(negative_slope=0.2)
                self.leaky_relu_in_place = torch.nn.LeakyReLU(
                    negative_slope=0.08, inplace=True
                )
                self.leaky_relu_functional_default = torch.nn.functional.leaky_relu

            def forward(self, x):
                a = self.leaky_relu_out_of_place(x)
                a = self.leaky_relu_in_place(a)
                a = self.leaky_relu_functional_default(a)
                return a

        self.lower_and_test_with_partitioner(
            LeakyReLUModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    @unittest.skip
    def test_mps_channels_last_tagged_reshape_pass_output(self):
        op_sequences = OpSequencesAddConv2d(2, 2)
        op_sequences.eval()

        example_inputs = (torch.ones(1, 1, 6, 6),)

        self.lower_and_test_with_partitioner(
            op_sequences, example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_conv2d_bn_hardtanh_mean_sequence(self):
        """
        This test makes sure that we can fuse batchnorm and hardtanh
        even when copy nodes are inserted at some spots in the graph
        to change the memory format.
        """
        groups = 1
        stride = [2, 2]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = 2
        out_channels = 1
        width = 8
        height = 8
        batches = 1
        example_inputs = (torch.randn(batches, in_channels, height, width),)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=(3, 3),
                    stride=stride,
                    padding=padding,
                    groups=groups,
                    dilation=dilation,
                    bias=True,
                )
                self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
                self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)

            def forward(self, x):
                x = self.conv(x)
                x = self.native_batchnorm(x)
                x = self.hardtanh(x)
                x = torch.mean(x, (-1, -2), keepdim=True)
                return x

        test_module = TestModule()
        test_module.eval()
        self.lower_and_test_with_partitioner(
            test_module, example_inputs, func_name=inspect.stack()[0].function[5:]
        )

    @unittest.expectedFailure
    def test_mps_backend_maximum_no_broadcast(self):
        model_inputs_no_broadcast = (torch.randn(2, 3, 4), torch.randn(2, 3, 4))

        self.lower_and_test_with_partitioner(
            torch.maximum,
            model_inputs_no_broadcast,
            func_name=inspect.stack()[0].function[5:],
        )

    @unittest.expectedFailure
    def test_mps_backend_maximum_broadcast(self):
        model_inputs_broadcast = (torch.randn(2, 3, 4), torch.randn(2, 1, 4))

        self.lower_and_test_with_partitioner(
            torch.maximum,
            model_inputs_broadcast,
            func_name=inspect.stack()[0].function[5:],
        )

    def test_mps_backend_negative(self):
        model_inputs = (torch.randn(1, 3, 3),)

        class NegModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.neg(x)

        self.lower_and_test_with_partitioner(
            NegModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_remainder_1(self):
        model_inputs = (torch.randn(1, 3, 3), torch.randn(1, 3, 3))

        class RemainderModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return torch.remainder(x, y)

        self.lower_and_test_with_partitioner(
            RemainderModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_remainder_2(self):
        model_inputs = (torch.randn(1, 3, 3),)

        class RemainderModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.remainder(x, 0.5)

        self.lower_and_test_with_partitioner(
            RemainderModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_square(self):
        model_inputs = (torch.randn(1, 3, 3),)

        class SquareModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.square(x)

        self.lower_and_test_with_partitioner(
            SquareModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_pow_1(self):
        model_inputs = (torch.randn(1, 3, 3),)

        class PowModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.pow(x, 4)

        self.lower_and_test_with_partitioner(
            PowModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_pow_2(self):
        model_inputs = (torch.randn(1, 3, 3), torch.tensor(4))

        class PowModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return torch.pow(x, y)

        self.lower_and_test_with_partitioner(
            PowModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_elu(self):
        model_inputs = (torch.randn(1, 3, 3),)

        class ELUModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.nn.functional.elu(x)

        self.lower_and_test_with_partitioner(
            ELUModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_avg_pool_2d_1(self):
        model_inputs = (torch.randn(1, 1, 10, 10),)

        class AvgPoolModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.avgPool = torch.nn.AvgPool2d(
                    kernel_size=(2, 2),
                    padding=(1, 1),
                    stride=(2, 2),
                    count_include_pad=False,
                )

            def forward(self, x):
                return self.avgPool(x)

        self.lower_and_test_with_partitioner(
            AvgPoolModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_avg_pool_2d_2(self):
        model_inputs = (torch.randn(1, 1, 10, 10),)

        class AvgPoolModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.avgPool = torch.nn.AvgPool2d(
                    kernel_size=(2, 2),
                    padding=(1, 1),
                    stride=(2, 2),
                    count_include_pad=True,
                )

            def forward(self, x):
                return self.avgPool(x)

        self.lower_and_test_with_partitioner(
            AvgPoolModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_avg_pool_2d_3(self):
        model_inputs = (torch.randn(1, 1, 10, 10),)

        class AvgPoolModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.avgPool = torch.nn.AvgPool2d(
                    kernel_size=(2, 2),
                    padding=(1, 1),
                    stride=(2, 2),
                    count_include_pad=False,
                    ceil_mode=True,
                    divisor_override=4,
                )

            def forward(self, x):
                return self.avgPool(x)

        self.lower_and_test_with_partitioner(
            AvgPoolModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )
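
    # The three AvgPool2d variants above toggle the padding semantics:
    # count_include_pad=True counts the zero padding in the averaging
    # denominator, ceil_mode=True rounds the output size up, and
    # divisor_override=4 replaces the denominator with a fixed 4 regardless
    # of how many elements the window covers.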
    def test_mps_backend_concatenate2(self):
        class Concat(torch.nn.Module):
            def forward(self, x, y):
                return torch.cat((y, x), 0)

        self.lower_and_test_with_partitioner(
            Concat(),
            (torch.ones(4, 2, 3), torch.randn(1, 2, 3)),
            func_name=inspect.stack()[0].function[5:],
        )

    def test_mps_backend_concatenate3(self):
        class Concat(torch.nn.Module):
            def forward(self, x, y):
                return torch.concat((y, y, x), 0)

        self.lower_and_test_with_partitioner(
            Concat(),
            (torch.ones(4, 2, 3), torch.randn(1, 2, 3)),
            func_name=inspect.stack()[0].function[5:],
        )

    def test_mps_backend_concatenate4(self):
        class Concat(torch.nn.Module):
            def forward(self, x, y):
                return torch.concatenate((y, x, y, x), 2)

        self.lower_and_test_with_partitioner(
            Concat(),
            (torch.randn(1, 2, 3), torch.randn(1, 2, 5)),
            func_name=inspect.stack()[0].function[5:],
        )

    def test_mps_backend_concatenate_nhwc(self):
        class Concat(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=1,
                    out_channels=3,
                    kernel_size=(3, 3),
                    padding=1,
                    bias=False,
                )

            def forward(self, x, y):
                x = self.conv(x)
                return torch.concatenate((y, x, y, x), 1)

        self.lower_and_test_with_partitioner(
            Concat(),
            (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3)),
            func_name=inspect.stack()[0].function[5:],
        )

    def test_mps_backend_concatenate_nhwc2(self):
        class Concat(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=1,
                    out_channels=3,
                    kernel_size=(3, 3),
                    padding=1,
                    bias=False,
                )

            def forward(self, x, y):
                x = self.conv(x)
                y = self.conv(y)
                return torch.concatenate((y, x, y, x), 3)

        self.lower_and_test_with_partitioner(
            Concat(),
            (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3)),
            func_name=inspect.stack()[0].function[5:],
        )

    def test_mps_backend_slice_copy(self):
        class Slice(torch.nn.Module):
            def forward(self, x):
                return x[1:3, -2:, :-1]

        self.lower_and_test_with_partitioner(
            Slice(), (torch.randn(5, 5, 5),), func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_slice_copy_stride_non_1(self):
        # Slices with a step other than 1 are unsupported; lowering should raise.
        class Slice(torch.nn.Module):
            def forward(self, x):
                return x[:3:-1, 2:, :3]

        self.assertRaises(
            Exception,
            self.lower_and_test_with_partitioner,
            Slice(),
            (torch.randn(5, 5, 5),),
            func_name=inspect.stack()[0].function[5:],
        )

    def test_mps_backend_slice_copy_dim_0(self):
        class Slice(torch.nn.Module):
            def forward(self, x):
                return x[-1:3, 2:, 3:3]

        self.lower_module_and_test_output(
            Slice(),
            (torch.randn(5, 5, 5),),
            use_partitioner=False,
            func_name=inspect.stack()[0].function[5:],
        )

    def test_mps_backend_slice_copy_memory_format(self):
        class ConvSlice(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=1,
                    out_channels=3,
                    kernel_size=(3, 3),
                    padding=1,
                    bias=False,
                )

            def forward(self, x):
                y = self.conv(x)
                return y[:, :, 2:3, -2:]

        self.lower_and_test_with_partitioner(
            ConvSlice(),
            (torch.randn(1, 1, 3, 3),),
            func_name=inspect.stack()[0].function[5:],
        )

    def test_mps_backend_bitwise_and(self):
        class BitwiseAnd(torch.nn.Module):
            def forward(self, x, y):
                return torch.bitwise_and(x, y)

        model_inputs = (
            torch.tensor([-1, -2, 3], dtype=torch.int8),
            torch.tensor([1, 0, 3], dtype=torch.int8),
        )
        self.lower_and_test_with_partitioner(
            BitwiseAnd(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_bitwise_or(self):
        class BitwiseOr(torch.nn.Module):
            def forward(self, x, y):
                return torch.bitwise_or(x, y)

        model_inputs = (
            torch.tensor([-1, -2, 3], dtype=torch.int8),
            torch.tensor([1, 0, 3], dtype=torch.int8),
        )
        self.lower_and_test_with_partitioner(
            BitwiseOr(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_bitwise_xor(self):
        class BitwiseXor(torch.nn.Module):
            def forward(self, x, y):
                return torch.bitwise_xor(x, y)

        model_inputs = (
            torch.tensor([True, True, False]),
            torch.tensor([False, True, False]),
        )
        self.lower_and_test_with_partitioner(
            BitwiseXor(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_bitwise_not(self):
        class BitwiseNot(torch.nn.Module):
            def forward(self, x):
                return torch.bitwise_not(x)

        model_inputs = (torch.tensor([-1, -2, 3], dtype=torch.int8),)
        self.lower_and_test_with_partitioner(
            BitwiseNot(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_bitwise_not_with_bool(self):
        class BitwiseNot(torch.nn.Module):
            def forward(self, x):
                return torch.bitwise_not(x)

        model_inputs = (torch.tensor([True, True, False]),)
        self.lower_and_test_with_partitioner(
            BitwiseNot(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_bitwise_with_scalar(self):
        class BitwiseScalarModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._scalar = 3

            def forward(self, x):
                out1 = torch.ops.aten.bitwise_and.Scalar(x, self._scalar)
                return out1

        model_inputs = (torch.tensor([-1, -2, 3], dtype=torch.int8),)
        self.lower_and_test_with_partitioner(
            BitwiseScalarModule(),
            model_inputs,
            func_name=inspect.stack()[0].function[5:],
        )

    def test_mps_backend_arange(self):
        class ArangeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._begin = 2.5
                self._end = 5
                self._step = 0.5

            def forward(self):
                # Both ranges produce 5 elements, so the elementwise sum is
                # well-formed: [0..4] + [2.5, 3.0, 3.5, 4.0, 4.5].
                out1 = torch.arange(end=self._end)
                out2 = torch.arange(start=self._begin, end=self._end, step=self._step)
                return out1 + out2

        self.lower_and_test_with_partitioner(
            ArangeModule(), (), func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_where(self):
        class Where(torch.nn.Module):
            def forward(self, cond, x, y):
                return torch.where(cond, x, y)

        x = torch.randn(3, 2)
        y = torch.ones(3, 2)
        cond = x > 0
        module_inputs = (cond, x, y)
        self.lower_and_test_with_partitioner(
            Where(), module_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_scalar_tensor(self):
        class ScalarTensorModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._scalar = 3.0
                self._bool = True

            def forward(self):
                out1 = torch.ops.aten.scalar_tensor(self._scalar)
                out2 = torch.ops.aten.scalar_tensor(self._scalar, dtype=torch.int32)
                # issue 121117206
                out3 = torch.ops.aten.scalar_tensor(self._bool, dtype=torch.bool)
                return out1 + out2 + out3

        self.lower_and_test_with_partitioner(
            ScalarTensorModule(), (), func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_tril(self):
        class TrilModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._k = 1
                self._negK = -1

            def forward(self, x):
                out1 = torch.tril(x, diagonal=self._k)
                out2 = torch.tril(x, diagonal=self._negK)
                return out1 + out2

        model_inputs = (torch.randn(4, 6),)
        self.lower_and_test_with_partitioner(
            TrilModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )

    def test_mps_backend_embedding(self):
        class EmbeddingModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self._embedding = torch.nn.Embedding(10, 3)
                self._embedding_with_padding = torch.nn.Embedding(10, 3, padding_idx=2)

            def forward(self, x):
                return self._embedding(x) + self._embedding_with_padding(x)

        model_inputs = (torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]),)
        self.lower_and_test_with_partitioner(
            EmbeddingModule(), model_inputs, func_name=inspect.stack()[0].function[5:]
        )
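
    # Note: the func_name passed to the test helpers is derived as
    # inspect.stack()[0].function[5:], which strips the leading "test_" from
    # the current test's name (e.g. "test_mps_backend_embedding" ->
    # "mps_backend_embedding"), presumably so any artifacts the harness emits
    # are named after the op under test.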