# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import ctypes
import unittest
from typing import Tuple

import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

import torch

from executorch.backends.transforms.convert_dtype_pass import I64toI32

from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend

from executorch.exir import EdgeCompileConfig
from torch.export import Dim, export, ExportedProgram

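# Preload the Vulkan loader library so that Vulkan symbols are available to the
# delegate runtime in this test process (assumed to be required by the test
# environment's Vulkan setup).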
ctypes.CDLL("libvulkan.so.1")


from executorch.exir import to_edge_transform_and_lower
from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten


class TestBackends(unittest.TestCase):
    _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
        _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
    )

    def assert_outputs_equal(
        self,
        model_output,
        ref_output,
        atol=1e-03,
        rtol=1e-03,
        first_output_only=False,
        equal_nan=True,
    ):
        """
        Helper testing function that asserts that the model output and the reference
        output are equal within some tolerance. Due to numerical differences between
        eager mode and the Vulkan backend, we relax the tolerances: the default
        absolute tolerance is 1e-3 and the default relative tolerance is 1e-3.
        """

        # Compare the result from the executor and eager mode directly
        if isinstance(ref_output, tuple) or isinstance(ref_output, list):
            # The executor always returns a tuple of outputs, even if there is only one
            self.assertTrue(len(ref_output) == len(model_output))
            if first_output_only:
                self.assertTrue(
                    torch.allclose(
                        model_output[0],
                        ref_output[0],
                        atol=atol,
                        rtol=rtol,
                        equal_nan=equal_nan,
                    )
                )
            else:
                for i in range(len(ref_output)):
                    self.assertTrue(
                        torch.allclose(
                            model_output[i],
                            ref_output[i],
                            atol=atol,
                            rtol=rtol,
                            equal_nan=equal_nan,
                        )
                    )
        else:
            # For a single output, eager mode returns a tensor while the executor returns a tuple of size 1
            self.assertTrue(
                torch.allclose(
                    model_output[0],
                    ref_output,
                    atol=atol,
                    rtol=rtol,
                    equal_nan=equal_nan,
                )
            )

    def lower_module_and_test_output(
        self,
        model: torch.nn.Module,
        sample_inputs: Tuple[torch.Tensor],
        atol=1e-03,
        rtol=1e-01,
        dynamic_shapes=None,
        test_inputs=None,
        memory_layouts=None,
        first_output_only=False,
    ):
        """
        Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
        the given sample inputs. It then runs the lowered module, once per memory layout
        under test, and compares its outputs with the outputs of the eager module.
        """

        def run_test(memory_layout):
            compile_options = {
                "memory_layout_override": memory_layout,
            }

            # The model should at least run in eager mode.
            model.eval()
            model(*sample_inputs)

            program: ExportedProgram = export(
                model, sample_inputs, dynamic_shapes=dynamic_shapes
            )

            edge_program = to_edge_transform_and_lower(
                program,
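                # The I64toI32 pass rewrites int64 tensors in the graph to int32;
                # the Vulkan delegate is assumed to compute with 32-bit integers.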
                transform_passes=[
                    I64toI32(self._edge_compile_config._skip_dim_order),
                ],
                partitioner=[VulkanPartitioner(compile_options)],
            )
            executorch_program = edge_program.to_executorch()

            self.assertEqual(
                executorch_program.executorch_program.execution_plan[0].delegates[0].id,
                VulkanBackend.__name__,
            )

            executorch_module = _load_for_executorch_from_buffer(
                executorch_program.buffer
            )
            inputs_flattened, _ = tree_flatten(sample_inputs)

            model_output = executorch_module.run_method(
                "forward", tuple(inputs_flattened)
            )
            ref_output = model(*sample_inputs)

            self.assert_outputs_equal(
                model_output,
                ref_output,
                atol=atol,
                rtol=rtol,
                first_output_only=first_output_only,
            )

            if test_inputs is not None:
                for test_input in test_inputs:
                    test_inputs_flattened, _ = tree_flatten(test_input)
                    model_output = executorch_module.run_method(
                        "forward", tuple(test_inputs_flattened)
                    )
                    ref_output = model(*test_input)

                    self.assert_outputs_equal(
                        model_output,
                        ref_output,
                        atol=atol,
                        rtol=rtol,
                        first_output_only=first_output_only,
                    )

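        # By default, run the test under both width-packed and channels-packed
        # memory layouts.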
        memory_layouts_to_test = [
            vk_graph_schema.VkMemoryLayout.TENSOR_WIDTH_PACKED,
            vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED,
        ]

        if memory_layouts is not None:
            memory_layouts_to_test = memory_layouts

        for memory_layout in memory_layouts_to_test:
            run_test(memory_layout)

    def test_vulkan_backend_add(self):
        # This is the simplest test. Rather than manually lowering submodules, we can
        # rely on the partitioner to auto-detect the lowerable parts of the graph.
        class AddModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y, w):
                z = x + y
                z = z + x
                z = z + x
                z = z + w
                z = w + z
                z = z + 3  # test scalar broadcasting
                return z

        add_module = AddModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 1), dtype=torch.float32),  # test broadcasting
        )

        self.lower_module_and_test_output(add_module, sample_inputs)

        sample_inputs = (
            torch.rand(size=(4, 5, 2, 3), dtype=torch.float32),
            torch.rand(size=(4, 5, 2, 3), dtype=torch.float32),
            torch.rand(
                size=(2, 3), dtype=torch.float32
            ),  # test broadcasting on packed dim
        )

        self.lower_module_and_test_output(add_module, sample_inputs)

    def test_vulkan_backend_add_int(self):
        class AddIntModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = x + y
                return z

        add_int_module = AddIntModule()
        sample_inputs = (
            torch.randint(low=-100, high=100, size=(2, 3), dtype=torch.int32),
            torch.randint(low=-100, high=100, size=(2, 3), dtype=torch.int32),
        )

        self.lower_module_and_test_output(add_int_module, sample_inputs)

    def test_vulkan_backend_zero_dim_tensor(self):
        class ZeroDimModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
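                # A zero-dimensional (scalar) tensor; the name refers to its
                # dimensionality, not its value.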
                self.zero = torch.full([], 1.3, dtype=torch.float32)

            def forward(self, x):
                return x + self.zero

        zero_dim_module = ZeroDimModule()
        sample_inputs = (torch.rand(size=(2, 3), dtype=torch.float32),)
        self.lower_module_and_test_output(zero_dim_module, sample_inputs)

    def test_vulkan_backend_internal_data(self):
        class InternalDataModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.rand(size=(2, 3), dtype=torch.float32)

            def forward(self, x, y):
                inter1 = torch.add(x, y, alpha=2)
                inter2 = torch.add(x, y, alpha=3.14)
                inter3 = inter1 * self.weight
                inter4 = inter2 * self.weight
                return inter4 - inter3

        internal_data_module = InternalDataModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(internal_data_module, sample_inputs)

    def test_vulkan_backend_sub(self):
        class SubModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = torch.sub(x, y, alpha=2)
                z = torch.sub(z, x, alpha=3.14)
                z = z - x
                return z

        sub_module = SubModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(sub_module, sample_inputs)

    def test_vulkan_backend_mul(self):
        class MulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = x * y
                z = z * x
                z = z * x
                return z

        mul_module = MulModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(mul_module, sample_inputs)

    def test_vulkan_backend_div(self):
        class DivModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = x / y
                z = z / x
                z = z / x
                return z

        div_module = DivModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(div_module, sample_inputs)

    def test_vulkan_backend_arithmetic(self):
        class ArithmeticModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.rand(size=(2, 3), dtype=torch.float32)

            def forward(self, x, y):
                z = x + y
                z = z - x
                z = z / x
                z = z * self.weight
                return z

        arithmetic_module = ArithmeticModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(arithmetic_module, sample_inputs)

    def test_vulkan_backend_floor_div(self):
        class FloorDivModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = x // y
                return z

        floor_div_module = FloorDivModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32) * 10.0,
            torch.rand(size=(2, 3), dtype=torch.float32) + 1.0,
        )

        # absolute tolerance is 1 because of flooring
        self.lower_module_and_test_output(
            floor_div_module, sample_inputs, atol=1.0 + 1e-03
        )

    def test_vulkan_backend_pow(self):
        class PowModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                z = torch.pow(x, y)
                return z

        pow_module = PowModule()
        sample_inputs = (
            torch.rand(size=(2, 3), dtype=torch.float32),
            torch.rand(size=(2, 3), dtype=torch.float32),
        )

        self.lower_module_and_test_output(pow_module, sample_inputs)

    def lower_unary_module_and_test_output(self, module):
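        # Export with a dynamic batch dimension (max 8), then re-run the lowered
        # module on several differently shaped test inputs.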
        batch = Dim("batch", max=8)
        sample_inputs = (torch.randn(8, 16, 96, 92),)

        dynamic_shapes = {"x": {0: batch}}
        test_inputs = [
            (torch.randn(3, 14, 15, 92),),
            (torch.randn(6, 5, 35, 89),),
            (torch.randn(7, 9, 32, 38),),
        ]

        self.lower_module_and_test_output(
            module,
            sample_inputs,
            dynamic_shapes=dynamic_shapes,
            test_inputs=test_inputs,
        )

    def test_vulkan_backend_clamp(self):
        class ClampModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.clamp(x, min=-3.14)

        self.lower_unary_module_and_test_output(ClampModule())

    def test_vulkan_backend_clamp_int(self):
        class ClampModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.clamp(x, min=-3)

        sample_inputs = (
            torch.randint(low=-100, high=100, size=(5, 5), dtype=torch.int32),
        )

        self.lower_module_and_test_output(ClampModule(), sample_inputs)

    def test_vulkan_backend_clamp_int64(self):
        class ClampModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.clamp(x, min=-3)

        sample_inputs = (
            torch.randint(low=-100, high=100, size=(5, 5), dtype=torch.int64),
        )

        self.lower_module_and_test_output(ClampModule(), sample_inputs)

    def test_vulkan_backend_cos(self):
        class CosModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.cos(x)

        self.lower_unary_module_and_test_output(CosModule())

    def test_vulkan_backend_hardtanh(self):
        class HardTanHModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tanh = torch.nn.Hardtanh(min_val=-3.14, max_val=6.28)

            def forward(self, x):
                return self.tanh(x)

        self.lower_unary_module_and_test_output(HardTanHModule())

    def test_vulkan_backend_exp(self):
        class ExpModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.exp(x)

        self.lower_unary_module_and_test_output(ExpModule())

    def test_vulkan_backend_neg(self):
        class NegModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.neg(x)

        self.lower_unary_module_and_test_output(NegModule())

    def test_vulkan_backend_sin(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        self.lower_unary_module_and_test_output(SinModule())

    def test_vulkan_backend_relu(self):
        class ReLUModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.relu(x)

        self.lower_unary_module_and_test_output(ReLUModule())

    def test_vulkan_backend_sqrt(self):
        class SqrtModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sqrt(x)

        self.lower_unary_module_and_test_output(SqrtModule())

    def test_vulkan_backend_hardshrink(self):
        class HardshrinkModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.hardshrink = torch.nn.Hardshrink(lambd=0.3)

            def forward(self, x):
                return self.hardshrink(x)

        self.lower_unary_module_and_test_output(HardshrinkModule())

    def test_vulkan_backend_max_pool2d(self):
        class MaxPool2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.max_pool = torch.nn.MaxPool2d(
                    kernel_size=(2, 3),
                    stride=(1, 1),
                    padding=0,
                    dilation=1,
                    ceil_mode=False,
                    return_indices=True,
                )

            def forward(self, x):
                return self.max_pool(x)

        max_pool2d_module = MaxPool2dModule()
        sample_inputs = (torch.randn(5, 13, 55, 68),)

        batch = Dim("batch", max=8)
        dynamic_shapes = {"x": {0: batch}}
        test_inputs = [
            (torch.randn(3, 14, 15, 9),),
            (torch.randn(1, 1, 4, 6),),
            (torch.randn(5, 10, 50, 40),),
        ]
        self.lower_module_and_test_output(
            max_pool2d_module,
            sample_inputs,
            dynamic_shapes=dynamic_shapes,
            test_inputs=test_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
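            # With return_indices=True, MaxPool2d returns (values, indices), so only
            # the pooled values are compared against eager mode.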
            first_output_only=True,
        )

    def test_vulkan_backend_avg_pool2d(self):
        class AvgPool2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.avg_pool = torch.nn.AvgPool2d(
                    kernel_size=(4, 4),
                    stride=(4, 4),
                    padding=(0, 0),
                    ceil_mode=True,
                    count_include_pad=True,
                    divisor_override=None,
                )

            def forward(self, x):
                return self.avg_pool(x)

        avg_pool2d_module = AvgPool2dModule()
        sample_inputs = (torch.randn(5, 13, 55, 68),)

        batch = Dim("batch", max=8)
        dynamic_shapes = {"x": {0: batch}}
        test_inputs = [
            (torch.randn(3, 14, 15, 9),),
            (torch.randn(1, 1, 4, 6),),
            (torch.randn(5, 10, 50, 40),),
        ]
        self.lower_module_and_test_output(
            avg_pool2d_module,
            sample_inputs,
            dynamic_shapes=dynamic_shapes,
            test_inputs=test_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_abs(self):
        class AbsModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.abs(x)

        self.lower_unary_module_and_test_output(AbsModule())

    def test_vulkan_backend_sigmoid(self):
        class SigmoidModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sigmoid(x)

        self.lower_unary_module_and_test_output(SigmoidModule())

    def test_vulkan_backend_tanh(self):
        class TanhModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.tanh(x)

        self.lower_unary_module_and_test_output(TanhModule())

    def test_vulkan_backend_linear(self):
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(128, 64, bias=False)

            def forward(self, x):
                return self.linear(x)

        module = LinearModule()
        sample_inputs = (torch.rand(size=(32, 128), dtype=torch.float32),)
        batch = Dim("batch", max=32)
        dynamic_shapes = {"x": {0: batch}}

        test_inputs = [
            (torch.rand(15, 128),),
            (torch.rand(6, 128),),
            (torch.rand(30, 128),),
            (torch.rand(20, 128),),
            (torch.rand(19, 128),),
        ]

        self.lower_module_and_test_output(
            module,
            sample_inputs,
            dynamic_shapes=dynamic_shapes,
            test_inputs=test_inputs,
        )

    def test_vulkan_backend_partial(self):
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.offset_1 = torch.rand(size=(2, 10), dtype=torch.float32)
                self.offset_2 = torch.rand(size=(2, 10), dtype=torch.float32)

            def forward(self, x):
                return self.linear(x + self.offset_1) - self.offset_2

        model = SimpleModel()
        sample_inputs = (torch.rand(size=(2, 10), dtype=torch.float32),)

        self.lower_module_and_test_output(model, sample_inputs)

    def test_vulkan_backend_partial_dynamic_shapes(self):
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.branch1 = torch.nn.Sequential(
                    torch.nn.Linear(64, 64), torch.nn.ReLU()
                )
                self.branch2 = torch.nn.Sequential(
                    torch.nn.Linear(128, 64), torch.nn.ReLU()
                )
                self.buffer_1 = torch.ones((1, 64)) * 0.5
                self.buffer_2 = torch.ones((1, 64)) * 1.4

            def forward(self, x1, x2):
                out1 = self.branch1(x1)
                out2 = self.branch2(x2)
                return (out1 + self.buffer_1 + out2) * self.buffer_2

        model = SimpleModel()
        sample_inputs = (torch.randn(32, 64), torch.randn(32, 128))
        batch = Dim("batch", max=32)
        dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

        test_inputs = [
            (torch.randn(15, 64), torch.randn(15, 128)),
            (torch.randn(6, 64), torch.randn(6, 128)),
            (torch.randn(30, 64), torch.randn(30, 128)),
            (torch.randn(20, 64), torch.randn(20, 128)),
            (torch.randn(19, 64), torch.randn(19, 128)),
        ]

        self.lower_module_and_test_output(
            model, sample_inputs, dynamic_shapes=dynamic_shapes, test_inputs=test_inputs
        )

    def test_vulkan_backend_matmul(self):
        class MatMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.ones(size=(63, 22), dtype=torch.float32)

            def forward(self, x):
                return torch.matmul(x, self.weight)

        module = MatMulModule()
        sample_inputs = (torch.ones(size=(31, 63), dtype=torch.float32),)

        self.lower_module_and_test_output(module, sample_inputs)

    def test_vulkan_backend_bmm(self):
        class BMMModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.randn(size=(4, 4, 5), dtype=torch.float32)

            def forward(self, x):
                return torch.bmm(x, self.weight)

        module = BMMModule()
        sample_inputs = (torch.randn(size=(4, 3, 4), dtype=torch.float32),)

        self.lower_module_and_test_output(module, sample_inputs)

    @unittest.skip(
        "Reduce shader does not support multiple reduction axes at the moment"
    )
    def test_vulkan_backend_sum_dim_list(self):
        class SumModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.sum(x, (0, -1), keepdim=True)
                x = torch.sum(x, 2, keepdim=False)
                return x

        module = SumModule()
        sample_inputs = (torch.ones(size=(3, 2, 7, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    @unittest.skip(
        "Reduce shader does not support multiple reduction axes at the moment"
    )
    def test_vulkan_backend_sum(self):
        class SumModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.sum(x, (), keepdim=True)
                x = torch.sum(x)
                return x

        module = SumModule()
        sample_inputs = (torch.rand(size=(3, 2, 7, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv2d(self):
        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=6,
                    out_channels=8,
                    kernel_size=(3, 3),
                    padding=(2, 3),
                    stride=(1, 2),
                    dilation=1,
                    groups=1,
                    bias=True,
                )

            def forward(self, x):
                return self.conv(x)

        conv2d_module = Conv2dModule()
        sample_inputs = (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv2d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv_transpose2d(self):
        class ConvTranspose2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.ConvTranspose2d(
                    in_channels=6,
                    out_channels=8,
                    kernel_size=(3, 3),
                    padding=(2, 3),
                    stride=(1, 2),
                    output_padding=(0, 1),
                    dilation=1,
                    groups=1,
                    bias=True,
                )

            def forward(self, x):
                return self.conv(x)

        conv_transpose2d_module = ConvTranspose2dModule()
        sample_inputs = (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv_transpose2d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv2d_dw(self):
        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=8,
                    out_channels=8,
                    kernel_size=3,
                    padding=1,
                    groups=8,
                    bias=True,
                )

            def forward(self, x):
                return self.conv(x)

        conv2d_module = Conv2dModule()
        sample_inputs = (torch.randn(size=(1, 8, 72, 96), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv2d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv2d_pw(self):
        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=8,
                    out_channels=8,
                    kernel_size=1,
                    padding=1,
                    groups=1,
                    bias=True,
                )

            def forward(self, x):
                return self.conv(x)

        conv2d_module = Conv2dModule()
        sample_inputs = (torch.randn(size=(1, 8, 72, 96), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv2d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv2d_bias_false(self):
        class Conv2dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=6,
                    out_channels=8,
                    kernel_size=(3, 3),
                    padding=(2, 3),
                    stride=(1, 2),
                    dilation=1,
                    groups=1,
                    bias=False,
                )

            def forward(self, x):
                return self.conv(x)

        conv2d_module = Conv2dModule()
        sample_inputs = (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv2d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv1d(self):
        class Conv1dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(
                    in_channels=20,
                    out_channels=10,
                    kernel_size=6,
                    stride=5,
                    padding=5,
                    dilation=3,
                    groups=5,
                    bias=True,
                )

            def forward(self, x):
                return self.conv(x)

        conv1d_module = Conv1dModule()
        sample_inputs = (torch.randn(size=(3, 20, 30), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv1d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv1d_bias_false(self):
        class Conv1dModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv1d(
                    in_channels=6,
                    out_channels=6,
                    kernel_size=3,
                    groups=6,
                    bias=False,
                )

            def forward(self, x):
                return self.conv(x)

        conv1d_module = Conv1dModule()
        sample_inputs = (torch.randn(size=(1, 6, 7), dtype=torch.float32),)

        self.lower_module_and_test_output(
            conv1d_module,
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_native_layer_norm(self):
        class NativeLayerNormModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer_norm = torch.nn.LayerNorm(5)

            def forward(self, x):
                return self.layer_norm(x)

        sample_inputs = (torch.randn(size=(3, 4, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            NativeLayerNormModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_batch_norm(self):
        class BatchNormModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(num_features=3)

            def forward(self, x):
                return self.bn(x)

        sample_inputs = (torch.randn(size=(4, 3, 2, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            BatchNormModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_full(self):
        class FullModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.full(x.shape, 42.0)

        class ZerosModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.zeros(x.shape)

        class OnesModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.ones(x.shape)

        sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            FullModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            ZerosModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            OnesModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_full_like(self):
        class FullLikeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.full_like(x, 42.0)

        class ZerosLikeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.zeros_like(x)

        class OnesLikeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.ones_like(x)

        sample_inputs = (torch.randn(size=(2, 3, 4, 5), dtype=torch.float32),)

        self.lower_module_and_test_output(
            FullLikeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            ZerosLikeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            OnesLikeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_upsample_nearest2d(self):
        class UpsampleNearest2d(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.upsample = torch.nn.Upsample(scale_factor=2, mode="nearest")

            def forward(self, x):
                return self.upsample(x)

        sample_inputs = (torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2),)

        self.lower_module_and_test_output(
            UpsampleNearest2d(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_minimum(self):
        class MinimumModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return torch.minimum(x, y)

        sample_inputs = (
            torch.rand(size=(3, 5, 6, 4), dtype=torch.float32),
            torch.rand(size=(6, 4), dtype=torch.float32),
        )

        self.lower_module_and_test_output(
            MinimumModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.reshape(x, [-1, x.size(-1)])

        sample_inputs = (torch.randn(size=(5, 3, 4), dtype=torch.float32),)

        self.lower_module_and_test_output(
            ReshapeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_view(self):
        class ViewModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.view([-1, x.size(-1)])

        sample_inputs = (torch.randn(size=(3, 2, 3, 4), dtype=torch.float32),)

        self.lower_module_and_test_output(
            ViewModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_view_int(self):
        class ViewModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.view([-1, x.size(-1)])

        sample_inputs = (torch.randint(size=(3, 6, 2, 7), high=100, dtype=torch.int32),)

        self.lower_module_and_test_output(
            ViewModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.unsqueeze(x, 1)
                x = torch.unsqueeze(x, 0)
                return x

        sample_inputs = (torch.randn(size=(3,), dtype=torch.float32),)

        self.lower_module_and_test_output(
            UnsqueezeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_squeeze(self):
        class SqueezeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.squeeze(x, 0)

        sample_inputs = (torch.randn(size=(1, 2, 2, 1), dtype=torch.float32),)

        self.lower_module_and_test_output(
            SqueezeModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_select(self):
        class SelectModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x[0][3]

        sample_inputs = (torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),)

        self.lower_module_and_test_output(
            SelectModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_permute_copy(self):
        class PermuteModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.permute(x, [3, 0, 2, 1])

        sample_inputs = (torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),)

        self.lower_module_and_test_output(
            PermuteModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_permute_copy_int(self):
        class PermuteModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.permute(x, [3, 0, 2, 1])

        sample_inputs = (torch.randint(size=(3, 6, 2, 7), high=100, dtype=torch.int32),)

        self.lower_module_and_test_output(
            PermuteModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_cat(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y, z, w):
                return torch.cat([x, y, z, w], dim=1)

        sample_inputs = (
            torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 1, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 9, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 3, 2, 7), dtype=torch.float32),
        )

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_cat_with_zero_size(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y, z, w):
                return torch.cat([x, y, z, w], dim=1)

        sample_inputs = (
            torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 0, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 0, 2, 7), dtype=torch.float32),
            torch.randn(size=(3, 3, 2, 7), dtype=torch.float32),
        )

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_slice(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x[:, 2:9:2, :]

        sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_split_with_sizes(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.split(x, (3, 6, 1, 3), dim=1)

        sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_split_tensor(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.tensor_split(x, 2, dim=1)

        sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_clone(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.clone(x)

        sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_constant_pad_nd(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.nn.functional.pad(x, (1, 2, 3, 4, 5, 6), "constant", 24.2)

        sample_inputs = (torch.randn(size=(3, 7, 5, 11), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_repeat(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.repeat([2, 3, 1, 2])

        sample_inputs = (torch.randn(size=(3, 7, 5, 9), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_t_default(self):
        # aten.permute_copy.default is not yet enabled in the partitioner
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                # torch.t is actually exported as aten::permute.
                return torch.t(x)

        sample_inputs = (torch.randn(size=(3, 14), dtype=torch.float32),)

        self.lower_module_and_test_output(
            TestModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    @unittest.skip(
        "Softmax shader with shared memory does not work with swiftshader due to potential swiftshader bug"
    )
    def test_vulkan_backend_softmax(self):
        class SoftmaxModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x.softmax(dim=0)
                x = x.softmax(dim=1)
                x = x.softmax(dim=2)
                return x

        sample_inputs = (torch.randn(size=(3, 2, 7), dtype=torch.float32),)

        self.lower_module_and_test_output(
            SoftmaxModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    @unittest.skip(
        "Softmax shader with shared memory does not work with swiftshader due to potential swiftshader bug"
    )
    def test_vulkan_backend_logsoftmax(self):
        class LogSoftmaxModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x.log_softmax(dim=0)
                x = x.log_softmax(dim=1)
                x = x.log_softmax(dim=2)
                return x

        sample_inputs = (torch.randn(size=(3, 2, 7), dtype=torch.float32),)

        self.lower_module_and_test_output(
            LogSoftmaxModule(),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_gelu(self):
        class GeluModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.gelu = torch.nn.GELU(approximate="tanh")

            def forward(self, x):
                return self.gelu(x)

        self.lower_unary_module_and_test_output(GeluModule())

    @unittest.skip(
        "Reduce shader does not support multiple reduction axes at the moment"
    )
    def test_vulkan_backend_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dims, keepdim=True):
                super().__init__()
                self.dims = dims
                self.keepdim = keepdim

            def forward(self, x):
                return torch.mean(x, self.dims, keepdim=self.keepdim)

        sample_inputs = (
            torch.arange(end=2 * 3 * 2 * 5, dtype=torch.float32).reshape(2, 3, 2, 5),
        )

        self.lower_module_and_test_output(
            MeanModule(dims=[-1, -2]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            MeanModule(dims=[1]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            MeanModule(dims=[0, 1, 2, 3]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            MeanModule(dims=[-1, -2], keepdim=False),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

        self.lower_module_and_test_output(
            MeanModule(dims=[1], keepdim=False),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_index_select_int(self):
        class IndexSelectModule(torch.nn.Module):
            def __init__(self, dim, indices):
                super().__init__()
                self.dim = dim
                self.index = torch.tensor(indices)

            def forward(self, x):
                return torch.index_select(x, self.dim, self.index)

        sample_inputs = (torch.arange(96).reshape(2, 8, 2, 3),)

        self.lower_module_and_test_output(
            IndexSelectModule(dim=1, indices=[2, 3, 5, 6, 7]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_index_select(self):
        class IndexSelectModule(torch.nn.Module):
            def __init__(self, dim, indices):
                super().__init__()
                self.dim = dim
                self.index = torch.tensor(indices)

            def forward(self, x):
                return torch.index_select(x, self.dim, self.index)

        sample_inputs = (torch.arange(144).reshape(12, 1, 3, 4).float(),)

        self.lower_module_and_test_output(
            IndexSelectModule(dim=0, indices=[1, 3, 5, 7, 8, 9, 10, 11, 2, 3]),
            sample_inputs,
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_arange_int(self):
        class ArangeModule(torch.nn.Module):
            def __init__(self, input):
                super().__init__()
                self.input = input

            def forward(self, x):
                return torch.arange(*self.input, dtype=torch.int32)

        # `torch.arange` could take one, two or three arguments as input.
        # If only one argument is provided, it will be interpreted as `end`.
        # If two arguments are provided, the first one will be interpreted as `start`
        # and the second one will be interpreted as `end`.
        # If three arguments are provided, the first one will be interpreted as `start`,
        # the second one will be interpreted as `end` and the third one will be
        # interpreted as `step`.
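        # For example, torch.arange(1, 11, 2) yields tensor([1, 3, 5, 7, 9]) and
        # torch.arange(12, 1, -2) yields tensor([12, 10, 8, 6, 4, 2]).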
        inputs = [
            [1],
            [-3, 5],
            [1, 11, 2],
            [12, 1, -2],
        ]
        for i in inputs:
            self.lower_module_and_test_output(
                ArangeModule(i),
                (torch.randn(size=(1,), dtype=torch.float32),),  # dummy input
                memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
            )

    def test_vulkan_backend_arange_float(self):
        class ArangeModule(torch.nn.Module):
            def __init__(self, input):
                super().__init__()
                self.input = input

            def forward(self, x):
                return torch.arange(*self.input)

        inputs = [
            [1.5],
            [-3, 5.0],
            [1.0, 11, 2],
            [12, 1, -2.0],
        ]
        for i in inputs:
            self.lower_module_and_test_output(
                ArangeModule(i),
                (torch.randn(size=(1,), dtype=torch.float32),),  # dummy input
                memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
            )

    def test_vulkan_backend_arange_int64(self):
        class ArangeModule(torch.nn.Module):
            def __init__(self, input):
                super().__init__()
                self.input = input

            def forward(self, x):
                return torch.arange(*self.input)

        inputs = [
            [1],
            [-3, 5],
            [1, 11, 2],
            [12, 1, -2],
            [1.5],
            [-3, 5.0],
            [1.0, 11, 2],
            [12, 1, -2.0],
        ]
        for i in inputs:
            self.lower_module_and_test_output(
                ArangeModule(i),
                (torch.randn(size=(1,), dtype=torch.float32),),  # dummy input
                memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
            )
            self.lower_module_and_test_output(
                ArangeModule(i),
                (torch.randint(low=-100, high=100, size=(5, 5)),),  # dummy input
                memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
            )

    def test_vulkan_backend_embedding_1d(self):
        class EmbeddingModule(torch.nn.Module):
            def __init__(self, embedding):
                super().__init__()
                self.embedding = embedding

            def forward(self, x):
                return self.embedding(x)

        self.lower_module_and_test_output(
            EmbeddingModule(torch.nn.Embedding(5, 4)),
            (torch.tensor([0, 1, 0, 4, 2, 0]),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_embedding_2d(self):
        class EmbeddingModule(torch.nn.Module):
            def __init__(self, embedding):
                super().__init__()
                self.embedding = embedding

            def forward(self, x):
                return self.embedding(x)

        self.lower_module_and_test_output(
            EmbeddingModule(torch.nn.Embedding(5, 4)),
            (torch.tensor([[0, 1, 0], [4, 2, 0]]),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_embedding_3d(self):
        class EmbeddingModule(torch.nn.Module):
            def __init__(self, embedding):
                super().__init__()
                self.embedding = embedding

            def forward(self, x):
                return self.embedding(x)

        self.lower_module_and_test_output(
            EmbeddingModule(torch.nn.Embedding(5, 4)),
            (torch.tensor([[[0, 1], [0, 1]], [[4, 2], [3, 3]]]),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_flip(self):
        class FlipModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.flip(x, [0, 1, 2, 3])

        self.lower_module_and_test_output(
            FlipModule(),
            (torch.arange(48).reshape(2, 3, 4, 2),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_conv_with_clamp(self):
        class ConvWithClampModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.randn(6, 8, 3, 3)
                self.bias = torch.randn(8)
                self.stride = (1, 2)
                self.padding = (2, 3)
                self.dilation = (1, 1)
                self.transposed = True
                self.output_padding = (0, 1)
                self.groups = 1
                self.output_min = 0
                self.output_max = 10

            def forward(self, x):
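                # et_vk.conv_with_clamp is assumed to fuse a (transposed) convolution
                # with clamping of the output to [output_min, output_max].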
                return torch.ops.et_vk.conv_with_clamp(
                    x,
                    self.weight,
                    self.bias,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.transposed,
                    self.output_padding,
                    self.groups,
                    self.output_min,
                    self.output_max,
                )

        self.lower_module_and_test_output(
            ConvWithClampModule(),
            (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )

    def test_vulkan_backend_grid_priors(self):
        class GridPriorsModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
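                # et_vk.grid_priors is assumed to generate a grid of prior (anchor)
                # points over the input feature map, spaced by `stride` and offset
                # by `offset`.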
                return torch.ops.et_vk.grid_priors(
                    x,
                    stride=8,
                    offset=0.5,
                )

        self.lower_module_and_test_output(
            GridPriorsModule(),
            (torch.rand(size=[1, 5, 2, 3]),),
            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        )