# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import ctypes
import unittest
from typing import Tuple

import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

import torch

from executorch.backends.transforms.convert_dtype_pass import I64toI32

from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend

from executorch.exir import EdgeCompileConfig
from torch.export import Dim, export, ExportedProgram

# Preload the Vulkan loader library so the delegate's symbols can be resolved
# when the pybindings are imported below.
ctypes.CDLL("libvulkan.so.1")


from executorch.exir import to_edge_transform_and_lower
from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten


class TestBackends(unittest.TestCase):
    _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
        _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
    )

    def assert_outputs_equal(
        self,
        model_output,
        ref_output,
        atol=1e-03,
        rtol=1e-03,
        first_output_only=False,
        equal_nan=True,
    ):
        """
        Helper testing function that asserts that the model output and the reference
        output are equal with some tolerance. Due to numerical differences between
        eager mode and the Vulkan backend, we relax the tolerances: the default
        absolute tolerance is 1e-3 and the default relative tolerance is 1e-3.
        """

        # Compare the results from the executor and eager mode directly.
        if isinstance(ref_output, tuple) or isinstance(ref_output, list):
            # Multiple outputs: the executor always returns a tuple, even if there
            # is only one output.
            self.assertTrue(len(ref_output) == len(model_output))
            if first_output_only:
                self.assertTrue(
                    torch.allclose(
                        model_output[0],
                        ref_output[0],
                        atol=atol,
                        rtol=rtol,
                        equal_nan=equal_nan,
                    )
                )
            else:
                for i in range(len(ref_output)):
                    self.assertTrue(
                        torch.allclose(
                            model_output[i],
                            ref_output[i],
                            atol=atol,
                            rtol=rtol,
                            equal_nan=equal_nan,
                        )
                    )
        else:
            # If there is one output, eager mode returns a tensor while the
            # executor returns a tuple of size 1.
            self.assertTrue(
                torch.allclose(
                    model_output[0],
                    ref_output,
                    atol=atol,
                    rtol=rtol,
                    equal_nan=equal_nan,
                )
            )

    def lower_module_and_test_output(
        self,
        model: torch.nn.Module,
        sample_inputs: Tuple[torch.Tensor],
        atol=1e-03,
        rtol=1e-01,
        dynamic_shapes=None,
        test_inputs=None,
        memory_layouts=None,
        first_output_only=False,
    ):
        """
        Helper testing function that takes a torch.nn.Module and lowers it to Vulkan
        with the given sample inputs. It then runs the lowered module and compares
        its outputs with the outputs of the eager module.
        """

        def run_test(memory_layout):
            compile_options = {
                "memory_layout_override": memory_layout,
            }

            # The model should at least run in eager mode.
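            # eval() switches modules such as Dropout and BatchNorm to inference
            # behavior, so the exported graph matches the eager reference below.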
            model.eval()
            model(*sample_inputs)

            program: ExportedProgram = export(
                model, sample_inputs, dynamic_shapes=dynamic_shapes
            )

            edge_program = to_edge_transform_and_lower(
                program,
                transform_passes=[
                    I64toI32(self._edge_compile_config._skip_dim_order),
                ],
                partitioner=[VulkanPartitioner(compile_options)],
            )
            executorch_program = edge_program.to_executorch()

            self.assertEqual(
                executorch_program.executorch_program.execution_plan[0].delegates[0].id,
                VulkanBackend.__name__,
            )

            executorch_module = _load_for_executorch_from_buffer(
                executorch_program.buffer
            )
            inputs_flattened, _ = tree_flatten(sample_inputs)

            model_output = executorch_module.run_method(
                "forward", tuple(inputs_flattened)
            )
            ref_output = model(*sample_inputs)

            self.assert_outputs_equal(
                model_output,
                ref_output,
                atol=atol,
                rtol=rtol,
                first_output_only=first_output_only,
            )

            if test_inputs is not None:
                for test_input in test_inputs:
                    test_inputs_flattened, _ = tree_flatten(test_input)
                    model_output = executorch_module.run_method(
                        "forward", tuple(test_inputs_flattened)
                    )
                    ref_output = model(*test_input)

                    self.assert_outputs_equal(
                        model_output,
                        ref_output,
                        atol=atol,
                        rtol=rtol,
                        first_output_only=first_output_only,
                    )

        memory_layouts_to_test = [
            vk_graph_schema.VkMemoryLayout.TENSOR_WIDTH_PACKED,
            vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED,
        ]

        if memory_layouts is not None:
            memory_layouts_to_test = memory_layouts

        for memory_layout in memory_layouts_to_test:
            run_test(memory_layout)

    def test_vulkan_backend_add(self):
        # This is the simplest test. Rather than manually lowering submodules, we
        # use the partitioner to auto-detect the lowerable parts of the graph.
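        # The VulkanPartitioner tags supported operators during
        # to_edge_transform_and_lower; the helper above then asserts that the
        # resulting program contains a Vulkan delegate before running it.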
        class AddModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y, w):
                z = x + y
                z = z + x
                z = z + x
                z = z + w
                z = w + z
                z = z + 3  # test scalar broadcasting
                return z

        add_module = AddModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 1), dtype=torch.float32),  # test broadcasting
        )

        self.lower_module_and_test_output(add_module, sample_inputs)

        sample_inputs = (
            torch.rand(size=(4, 5, 2, 3), dtype=torch.float32),
            torch.rand(size=(4, 5, 2, 3), dtype=torch.float32),
            torch.rand(
                size=(2, 3), dtype=torch.float32
            ),  # test broadcasting on packed dim
        )

        self.lower_module_and_test_output(add_module, sample_inputs)

    def test_vulkan_backend_add_int(self):
        class AddIntModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = x + y
                return z

        add_int_module = AddIntModule()
        sample_inputs = (
            torch.randint(low=-100, high=100, size=(2, 3), dtype=torch.int32),
            torch.randint(low=-100, high=100, size=(2, 3), dtype=torch.int32),
        )

        self.lower_module_and_test_output(add_int_module, sample_inputs)

    def test_vulkan_backend_zero_dim_tensor(self):
        class ZeroDimModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.zero = torch.full([], 1.3, dtype=torch.float32)

            def forward(self, x):
                return x + self.zero

        internal_data_module = ZeroDimModule()
        sample_inputs = (torch.rand(size=(2, 3), dtype=torch.float32),)
        self.lower_module_and_test_output(internal_data_module, sample_inputs)

    def test_vulkan_backend_internal_data(self):
        class InternalDataModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.rand(size=(2, 3), dtype=torch.float32)

            def forward(self, x, y):
                inter1 = torch.add(x, y, alpha=2)
                inter2 = torch.add(x, y, alpha=3.14)
                inter3 = inter1 * self.weight
                inter4 = inter2 * self.weight
                return inter4 - inter3

        internal_data_module = InternalDataModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(internal_data_module, sample_inputs)

    def test_vulkan_backend_sub(self):
        class SubModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = torch.sub(x, y, alpha=2)
                z = torch.sub(z, x, alpha=3.14)
                z = z - x
                return z

        sub_module = SubModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(sub_module, sample_inputs)

    def test_vulkan_backend_mul(self):
        class MulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = x * y
                z = z * x
                z = z * x
                return z

        mul_module = MulModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(mul_module, sample_inputs)

    def test_vulkan_backend_div(self):
        class DivModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = x / y
                z = z / x
                z = z / x
                return z

        div_module = DivModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(div_module, sample_inputs)

    def test_vulkan_backend_arithmetic(self):
        class ArithmeticModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.rand(size=(2, 3), dtype=torch.float32)

            def forward(self, x, y):
                z = x + y
                z = z - x
                z = z / x
                z = z * self.weight
                return z

        arithmetic_module = ArithmeticModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(arithmetic_module, sample_inputs)

    def test_vulkan_backend_floor_div(self):
        class FloorDivModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = x // y
                return z

        floor_div_module = FloorDivModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32) * 10.0,
            torch.rand(size=(2, 3), dtype=torch.float32) + 1.0,
        )

        # The absolute tolerance is 1 because of flooring.
        self.lower_module_and_test_output(
            floor_div_module, sample_inputs, atol=1.0 + 1e-03
        )

    def test_vulkan_backend_pow(self):
        class PowModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = torch.pow(x, y)
                return z

        pow_module = PowModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(pow_module, sample_inputs)

    def lower_unary_module_and_test_output(self, module):
        batch = Dim("batch", max=8)
        sample_inputs = (torch.randn(8, 16, 96, 92),)

        dynamic_shapes = {"x": {0: batch}}
        test_inputs = [
            (torch.randn(3, 14, 15, 92),),
            (torch.randn(6, 5, 35, 89),),
            (torch.randn(7, 9, 32, 38),),
        ]

        self.lower_module_and_test_output(
            module,
            sample_inputs,
            dynamic_shapes=dynamic_shapes,
            test_inputs=test_inputs,
        )

    def test_vulkan_backend_clamp(self):
        class ClampModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.clamp(x, min=-3.14)

        self.lower_unary_module_and_test_output(ClampModule())

    def test_vulkan_backend_clamp_int(self):
        class ClampModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.clamp(x, min=-3)

        sample_inputs = (
            torch.randint(low=-100, high=100, size=(5, 5), dtype=torch.int32),
        )

        self.lower_module_and_test_output(ClampModule(), sample_inputs)

    def test_vulkan_backend_clamp_int64(self):
        class ClampModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.clamp(x, min=-3)

        sample_inputs = (
            torch.randint(low=-100, high=100, size=(5, 5), dtype=torch.int64),
        )

        self.lower_module_and_test_output(ClampModule(), sample_inputs)

    def test_vulkan_backend_cos(self):
        class CosModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.cos(x)

        self.lower_unary_module_and_test_output(CosModule())

    def test_vulkan_backend_hardtanh(self):
        class HardTanHModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tanh = torch.nn.Hardtanh(min_val=-3.14, max_val=6.28)

            def forward(self, x):
                return self.tanh(x)

        self.lower_unary_module_and_test_output(HardTanHModule())

    def test_vulkan_backend_exp(self):
        class ExpModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.exp(x)

        self.lower_unary_module_and_test_output(ExpModule())

    def test_vulkan_backend_neg(self):
        class NegModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.neg(x)

        self.lower_unary_module_and_test_output(NegModule())

    def test_vulkan_backend_sin(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        self.lower_unary_module_and_test_output(SinModule())

    def test_vulkan_backend_relu(self):
        class ReLUModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.relu(x)

        self.lower_unary_module_and_test_output(ReLUModule())

    def test_vulkan_backend_sqrt(self):
        class SqrtModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sqrt(x)

        self.lower_unary_module_and_test_output(SqrtModule())

    def test_vulkan_backend_hardshrink(self):
        class HardshrinkModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.hardshrink = torch.nn.Hardshrink(lambd=0.3)

            def forward(self, x):
                return self.hardshrink(x)

        self.lower_unary_module_and_test_output(HardshrinkModule())

    def test_vulkan_backend_max_pool2d(self):
        class MaxPool2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.max_pool = torch.nn.MaxPool2d(
                    kernel_size=(2, 3),
                    stride=(1, 1),
                    padding=0,
                    dilation=1,
                    ceil_mode=False,
                    return_indices=True,
                )

            def forward(self, x):
                return self.max_pool(x)

        max_pool2d_module = MaxPool2dModule()
        sample_inputs = (torch.randn(5, 13, 55, 68),)

        batch = Dim("batch", max=8)
        dynamic_shapes = {"x": {0: batch}}
        test_inputs = [
            (torch.randn(3, 14, 15, 9),),
            (torch.randn(1, 1, 4, 6),),
            (torch.randn(5, 10, 50, 40),),
        ]
        self.lower_module_and_test_output(
            max_pool2d_module,
            sample_inputs,
            dynamic_shapes=dynamic_shapes,
            test_inputs=test_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
            first_output_only=True,
        )

    def test_vulkan_backend_avg_pool2d(self):
        class AvgPool2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.avg_pool = torch.nn.AvgPool2d(
                    kernel_size=(4, 4),
                    stride=(4, 4),
                    padding=(0, 0),
                    ceil_mode=True,
                    count_include_pad=True,
                    divisor_override=None,
                )

            def forward(self, x):
                return self.avg_pool(x)

        avg_pool2d_module = AvgPool2dModule()
        sample_inputs = (torch.randn(5, 13, 55, 68),)

        batch = Dim("batch", max=8)
        dynamic_shapes = {"x": {0: batch}}
        test_inputs = [
            (torch.randn(3, 14, 15, 9),),
            (torch.randn(1, 1, 4, 6),),
            (torch.randn(5, 10, 50, 40),),
        ]
        self.lower_module_and_test_output(
            avg_pool2d_module,
            sample_inputs,
            dynamic_shapes=dynamic_shapes,
            test_inputs=test_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_abs(self):
        class AbsModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.abs(x)

        self.lower_unary_module_and_test_output(AbsModule())

    def test_vulkan_backend_sigmoid(self):
        class SigmoidModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sigmoid(x)

        self.lower_unary_module_and_test_output(SigmoidModule())

    def test_vulkan_backend_tanh(self):
        class TanhModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.tanh(x)

        self.lower_unary_module_and_test_output(TanhModule())

    def test_vulkan_backend_linear(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(128, 64, bias=False)

            def forward(self, x):
                return self.linear(x)

        module = LinearModule()
        sample_inputs = (torch.rand(size=(32, 128), dtype=torch.float32),)
        batch = Dim("batch", max=32)
        dynamic_shapes = {"x": {0: batch}}

        test_inputs = [
            (torch.rand(15, 128),),
            (torch.rand(6, 128),),
            (torch.rand(30, 128),),
            (torch.rand(20, 128),),
            (torch.rand(19, 128),),
        ]

        self.lower_module_and_test_output(
            module,
            sample_inputs,
            dynamic_shapes=dynamic_shapes,
            test_inputs=test_inputs,
        )

    def test_vulkan_backend_partial(self):
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.offset_1 = torch.rand(size=(2, 10), dtype=torch.float32)
                self.offset_2 = torch.rand(size=(2, 10), dtype=torch.float32)

            def forward(self, x):
                return self.linear(x + self.offset_1) - self.offset_2

        model = SimpleModel()
        sample_inputs = (torch.rand(size=(2, 10), dtype=torch.float32),)

        self.lower_module_and_test_output(model, sample_inputs)

    def test_vulkan_backend_partial_dynamic_shapes(self):
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.branch1 = torch.nn.Sequential(
                    torch.nn.Linear(64, 64), torch.nn.ReLU()
                )
                self.branch2 = torch.nn.Sequential(
                    torch.nn.Linear(128, 64), torch.nn.ReLU()
                )
                self.buffer_1 = torch.ones((1, 64)) * 0.5
                self.buffer_2 = torch.ones((1, 64)) * 1.4

            def forward(self, x1, x2):
                out1 = self.branch1(x1)
                out2 = self.branch2(x2)
                return (out1 + self.buffer_1 + out2) * self.buffer_2

        model = SimpleModel()
        sample_inputs = (torch.randn(32, 64), torch.randn(32, 128))
        batch = Dim("batch", max=32)
        dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

        test_inputs = [
            (torch.randn(15, 64), torch.randn(15, 128)),
            (torch.randn(6, 64), torch.randn(6, 128)),
            (torch.randn(30, 64), torch.randn(30, 128)),
            (torch.randn(20, 64), torch.randn(20, 128)),
            (torch.randn(19, 64), torch.randn(19, 128)),
        ]

        self.lower_module_and_test_output(
            model, sample_inputs, dynamic_shapes=dynamic_shapes, test_inputs=test_inputs
        )

    def test_vulkan_backend_matmul(self):
        class MatMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.ones(size=(63, 22), dtype=torch.float32)

            def forward(self, x):
                return torch.matmul(x, self.weight)

        module = MatMulModule()
        sample_inputs = (torch.ones(size=(31, 63), dtype=torch.float32),)

        self.lower_module_and_test_output(module, sample_inputs)

    def test_vulkan_backend_bmm(self):
        class BMMModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.randn(size=(4, 4, 5), dtype=torch.float32)

            def forward(self, x):
                return torch.bmm(x, self.weight)

        module = BMMModule()
        sample_inputs = (torch.randn(size=(4, 3, 4), dtype=torch.float32),)

        self.lower_module_and_test_output(module, sample_inputs)

    @unittest.skip(
        "Reduce shader does not support multiple reduction axes at the moment"
    )
    def test_vulkan_backend_sum_dim_list(self):
        class SumModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.sum(x, (0, -1), keepdim=True)
                x = torch.sum(x, 2, keepdim=False)
                return x

        module = SumModule()
        sample_inputs = (torch.ones(size=(3, 2, 7, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    @unittest.skip(
        "Reduce shader does not support multiple reduction axes at the moment"
    )
    def test_vulkan_backend_sum(self):
        class SumModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.sum(x, (), keepdim=True)
                x = torch.sum(x)
                return x

        module = SumModule()
        sample_inputs = (torch.rand(size=(3, 2, 7, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv2d(self):
        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=6,
                    out_channels=8,
                    kernel_size=(3, 3),
                    padding=(2, 3),
                    stride=(1, 2),
                    dilation=1,
                    groups=1,
                    bias=True,
                )

            def forward(self, x):
                return self.conv(x)

        conv2d_module = Conv2dModule()
        sample_inputs = (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv2d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv_transpose2d(self):
        class ConvTranspose2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.ConvTranspose2d(
                    in_channels=6,
                    out_channels=8,
                    kernel_size=(3, 3),
                    padding=(2, 3),
                    stride=(1, 2),
                    output_padding=(0, 1),
                    dilation=1,
                    groups=1,
                    bias=True,
                )

            def forward(self, x):
                return self.conv(x)

        conv_transpose2d_module = ConvTranspose2dModule()
        sample_inputs = (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv_transpose2d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv2d_dw(self):
        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=8,
                    out_channels=8,
                    kernel_size=3,
                    padding=1,
                    groups=8,
                    bias=True,
                )

            def forward(self, x):
                return self.conv(x)

        conv2d_module = Conv2dModule()
        sample_inputs = (torch.randn(size=(1, 8, 72, 96), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv2d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv2d_pw(self):
        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=8,
                    out_channels=8,
                    kernel_size=1,
                    padding=1,
                    groups=1,
                    bias=True,
                )

            def forward(self, x):
                return self.conv(x)

        conv2d_module = Conv2dModule()
        sample_inputs = (torch.randn(size=(1, 8, 72, 96), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv2d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv2d_bias_false(self):
        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=6,
                    out_channels=8,
                    kernel_size=(3, 3),
                    padding=(2, 3),
                    stride=(1, 2),
                    dilation=1,
                    groups=1,
                    bias=False,
                )

            def forward(self, x):
                return self.conv(x)

        conv2d_module = Conv2dModule()
        sample_inputs = (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv2d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv1d(self):
        class Conv1dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(
                    in_channels=20,
                    out_channels=10,
                    kernel_size=6,
                    stride=5,
                    padding=5,
                    dilation=3,
                    groups=5,
                    bias=True,
                )

            def forward(self, x):
                return self.conv(x)

        conv1d_module = Conv1dModule()
        sample_inputs = (torch.randn(size=(3, 20, 30), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv1d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv1d_bias_false(self):
        class Conv1dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(
                    in_channels=6,
                    out_channels=6,
                    kernel_size=3,
                    groups=6,
                    bias=False,
                )

            def forward(self, x):
                return self.conv(x)

        conv1d_module = Conv1dModule()
        sample_inputs = (torch.randn(size=(1, 6, 7), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv1d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_native_layer_norm(self):
        class NativeLayerNormModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer_norm = torch.nn.LayerNorm(5)

            def forward(self, x):
                return self.layer_norm(x)

        sample_inputs = (torch.randn(size=(3, 4, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            NativeLayerNormModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_batch_norm(self):
        class BatchNormModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(num_features=3)

            def forward(self, x):
                return self.bn(x)

        sample_inputs = (torch.randn(size=(4, 3, 2, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            BatchNormModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_full(self):
        class FullModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.full(x.shape, 42.0)

        class ZerosModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.zeros(x.shape)

        class OnesModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.ones(x.shape)

        sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            FullModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            ZerosModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            OnesModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_full_like(self):
        class FullLikeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.full_like(x, 42.0)

        class ZerosLikeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.zeros_like(x)

        class OnesLikeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.ones_like(x)

        sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            FullLikeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            ZerosLikeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            OnesLikeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_upsample_nearest2d(self):
        class UpsampleNearest2d(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.upsample = torch.nn.Upsample(scale_factor=2, mode="nearest")

            def forward(self, x):
                return self.upsample(x)

        sample_inputs = (torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2),)

        self.lower_module_and_test_output(
            UpsampleNearest2d(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_minimum(self):
        class MinimumModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return torch.minimum(x, y)

        sample_inputs = (
            torch.rand(size=(3, 5, 6, 4), dtype=torch.float32),
            torch.rand(size=(6, 4), dtype=torch.float32),
        )

        self.lower_module_and_test_output(
            MinimumModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.reshape(x, [-1, x.size(-1)])

        sample_inputs = (torch.randn(size=(5, 3, 4), dtype=torch.float32),)

        self.lower_module_and_test_output(
            ReshapeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_view(self):
        class ViewModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.view([-1, x.size(-1)])

        sample_inputs = (torch.randn(size=(3, 2, 3, 4), dtype=torch.float32),)

        self.lower_module_and_test_output(
            ViewModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_view_int(self):
        class ViewModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.view([-1, x.size(-1)])

        sample_inputs = (torch.randint(size=(3, 6, 2, 7), high=100, dtype=torch.int32),)

        self.lower_module_and_test_output(
            ViewModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.unsqueeze(x, 1)
                x = torch.unsqueeze(x, 0)
                return x

        sample_inputs = (torch.randn(size=(3,), dtype=torch.float32),)

        self.lower_module_and_test_output(
            UnsqueezeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_squeeze(self):
        class SqueezeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.squeeze(x, 0)

        sample_inputs = (torch.randn(size=(1, 2, 2, 1), dtype=torch.float32),)

        self.lower_module_and_test_output(
            SqueezeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_select(self):
        class SelectModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x[0][3]

        sample_inputs = (torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),)

        self.lower_module_and_test_output(
            SelectModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_permute_copy(self):
        class PermuteModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.permute(x, [3, 0, 2, 1])

        sample_inputs = (torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),)

        self.lower_module_and_test_output(
            PermuteModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_permute_copy_int(self):
        class PermuteModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.permute(x, [3, 0, 2, 1])

        sample_inputs = (torch.randint(size=(3, 6, 2, 7), high=100, dtype=torch.int32),)

        self.lower_module_and_test_output(
            PermuteModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_cat(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y, z, w):
                return torch.cat([x, y, z, w], dim=1)

        sample_inputs = (
            torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 1, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 9, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 3, 2, 7), dtype=torch.float32),
        )

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_cat_with_zero_size(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y, z, w):
                return torch.cat([x, y, z, w], dim=1)

        sample_inputs = (
            torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 0, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 0, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 3, 2, 7), dtype=torch.float32),
        )

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_slice(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x[:, 2:9:2, :]

        sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_split_with_sizes(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.split(x, (3, 6, 1, 3), dim=1)

        sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_split_tensor(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.tensor_split(x, 2, dim=1)

        sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_clone(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.clone(x)

        sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_constant_pad_nd(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.nn.functional.pad(x, (1, 2, 3, 4, 5, 6), "constant", 24.2)

        sample_inputs = (torch.randn(size=(3, 7, 5, 11), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_repeat(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.repeat([2, 3, 1, 2])

        sample_inputs = (torch.randn(size=(3, 7, 5, 9), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_t_default(self):
        # aten.permute_copy.default is not enabled yet in partitioner
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                # torch.t is actually exported as aten::permute.
                return torch.t(x)

        sample_inputs = (torch.randn(size=(3, 14), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    @unittest.skip(
        "Softmax shader with shared memory does not work with swiftshader due to potential swiftshader bug"
    )
    def test_vulkan_backend_softmax(self):
        class SoftmaxModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x.softmax(dim=0)
                x = x.softmax(dim=1)
                x = x.softmax(dim=2)
                return x

        sample_inputs = (torch.randn(size=(3, 2, 7), dtype=torch.float32),)

        self.lower_module_and_test_output(
            SoftmaxModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    @unittest.skip(
        "Softmax shader with shared memory does not work with swiftshader due to potential swiftshader bug"
    )
    def test_vulkan_backend_logsoftmax(self):
        class LogSoftmaxModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x.log_softmax(dim=0)
                x = x.log_softmax(dim=1)
                x = x.log_softmax(dim=2)
                return x

        sample_inputs = (torch.randn(size=(3, 2, 7), dtype=torch.float32),)

        self.lower_module_and_test_output(
            LogSoftmaxModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_gelu(self):
        class GeluModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gelu = torch.nn.GELU(approximate="tanh")

            def forward(self, x):
                return self.gelu(x)

        self.lower_unary_module_and_test_output(GeluModule())

    @unittest.skip(
        "Reduce shader does not support multiple reduction axes at the moment"
    )
    def test_vulkan_backend_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dims, keepdim=True):
                super().__init__()
                self.dims = dims
                self.keepdim = keepdim

            def forward(self, x):
                return torch.mean(x, self.dims, keepdim=self.keepdim)

        sample_inputs = (
            torch.arange(end=2 * 3 * 2 * 5, dtype=torch.float32).reshape(2, 3, 2, 5),
        )
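        # The calls below cover reductions over two trailing dims, a single dim,
        # and all dims, both with and without keepdim.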
        self.lower_module_and_test_output(
            MeanModule(dims=[-1, -2]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            MeanModule(dims=[1]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            MeanModule(dims=[0, 1, 2, 3]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            MeanModule(dims=[-1, -2], keepdim=False),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            MeanModule(dims=[1], keepdim=False),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_index_select_int(self):
        class IndexSelectModule(torch.nn.Module):
            def __init__(self, dim, indices):
                super().__init__()
                self.dim = dim
                self.index = torch.tensor(indices)

            def forward(self, x):
                return torch.index_select(x, self.dim, self.index)

        sample_inputs = (torch.arange(96).reshape(2, 8, 2, 3),)

        self.lower_module_and_test_output(
            IndexSelectModule(dim=1, indices=[2, 3, 5, 6, 7]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_index_select(self):
        class IndexSelectModule(torch.nn.Module):
            def __init__(self, dim, indices):
                super().__init__()
                self.dim = dim
                self.index = torch.tensor(indices)

            def forward(self, x):
                return torch.index_select(x, self.dim, self.index)

        sample_inputs = (torch.arange(144).reshape(12, 1, 3, 4).float(),)

        self.lower_module_and_test_output(
            IndexSelectModule(dim=0, indices=[1, 3, 5, 7, 8, 9, 10, 11, 2, 3]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_arange_int(self):
        class ArangeModule(torch.nn.Module):
            def __init__(self, input):
                super().__init__()
                self.input = input

            def forward(self, x):
                return torch.arange(*self.input, dtype=torch.int32)

        # `torch.arange` can take one, two, or three arguments as input.
        # If only one argument is provided, it is interpreted as `end`.
        # If two arguments are provided, the first is interpreted as `start`
        # and the second as `end`.
        # If three arguments are provided, the first is interpreted as `start`,
        # the second as `end`, and the third as `step`.
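        # For example:
        #   torch.arange(5)         -> [0, 1, 2, 3, 4]
        #   torch.arange(1, 11, 2)  -> [1, 3, 5, 7, 9]
        #   torch.arange(12, 1, -2) -> [12, 10, 8, 6, 4, 2]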
        inputs = [
            [1],
            [-3, 5],
            [1, 11, 2],
            [12, 1, -2],
        ]
        for i in inputs:
            self.lower_module_and_test_output(
                ArangeModule(i),
                (torch.randn(size=(1,), dtype=torch.float32),),  # dummy input
                memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
            )

    def test_vulkan_backend_arange_float(self):
        class ArangeModule(torch.nn.Module):
            def __init__(self, input):
                super().__init__()
                self.input = input

            def forward(self, x):
                return torch.arange(*self.input)

        inputs = [
            [1.5],
            [-3, 5.0],
            [1.0, 11, 2],
            [12, 1, -2.0],
        ]
        for i in inputs:
            self.lower_module_and_test_output(
                ArangeModule(i),
                (torch.randn(size=(1,), dtype=torch.float32),),  # dummy input
                memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
            )

    def test_vulkan_backend_arange_int64(self):
        class ArangeModule(torch.nn.Module):
            def __init__(self, input):
                super().__init__()
                self.input = input

            def forward(self, x):
                return torch.arange(*self.input)

        inputs = [
            [1],
            [-3, 5],
            [1, 11, 2],
            [12, 1, -2],
            [1.5],
            [-3, 5.0],
            [1.0, 11, 2],
            [12, 1, -2.0],
        ]
        for i in inputs:
            self.lower_module_and_test_output(
                ArangeModule(i),
                (torch.randn(size=(1,), dtype=torch.float32),),  # dummy input
                memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
            )
            self.lower_module_and_test_output(
                ArangeModule(i),
                (torch.randint(low=-100, high=100, size=(5, 5)),),  # dummy input
                memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
            )

    def test_vulkan_backend_embedding_1d(self):
        class EmbeddingModule(torch.nn.Module):
            def __init__(self, embedding):
                super().__init__()
                self.embedding = embedding

            def forward(self, x):
                return self.embedding(x)

        self.lower_module_and_test_output(
            EmbeddingModule(torch.nn.Embedding(5, 4)),
            (torch.tensor([0, 1, 0, 4, 2, 0]),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_embedding_2d(self):
        class EmbeddingModule(torch.nn.Module):
            def __init__(self, embedding):
                super().__init__()
                self.embedding = embedding

            def forward(self, x):
                return self.embedding(x)

        self.lower_module_and_test_output(
            EmbeddingModule(torch.nn.Embedding(5, 4)),
            (torch.tensor([[0, 1, 0], [4, 2, 0]]),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_embedding_3d(self):
        class EmbeddingModule(torch.nn.Module):
            def __init__(self, embedding):
                super().__init__()
                self.embedding = embedding

            def forward(self, x):
                return self.embedding(x)

        self.lower_module_and_test_output(
            EmbeddingModule(torch.nn.Embedding(5, 4)),
            (torch.tensor([[[0, 1], [0, 1]], [[4, 2], [3, 3]]]),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_flip(self):
        class FlipModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.flip(x, [0, 1, 2, 3])

        self.lower_module_and_test_output(
            FlipModule(),
            (torch.arange(48).reshape(2, 3, 4, 2),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv_with_clamp(self):
        class ConvWithClampModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.randn(6, 8, 3, 3)
                self.bias = torch.randn(8)
                self.stride = (1, 2)
                self.padding = (2, 3)
                self.dilation = (1, 1)
                self.transposed = True
                self.output_padding = (0, 1)
                self.groups = 1
                self.output_min = 0
                self.output_max = 10

            def forward(self, x):
                return torch.ops.et_vk.conv_with_clamp(
                    x,
                    self.weight,
                    self.bias,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.transposed,
                    self.output_padding,
                    self.groups,
                    self.output_min,
                    self.output_max,
                )

        self.lower_module_and_test_output(
            ConvWithClampModule(),
            (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_grid_priors(self):
        class GridPriorsModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.ops.et_vk.grid_priors(
                    x,
                    stride=8,
                    offset=0.5,
                )

        self.lower_module_and_test_output(
            GridPriorsModule(),
            (torch.rand(size=[1, 5, 2, 3]),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )
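

# Standard unittest entry point so the suite can also be run directly; test
# discovery via `python -m unittest` works as well. Note that importing this
# module already requires libvulkan.so.1 to be loadable (see the ctypes.CDLL
# call near the top of the file).
if __name__ == "__main__":
    unittest.main()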