xref: /aosp_15_r20/external/pytorch/test/test_vulkan.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport unittest
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import functional as F
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests
8*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
9*da0073e9SAndroid Build Coastguard Workerimport io
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker@unittest.skipUnless(torch.is_vulkan_available(),
12*da0073e9SAndroid Build Coastguard Worker                     "Vulkan backend must be available for these tests.")
13*da0073e9SAndroid Build Coastguard Workerclass TestVulkanRewritePass(TestCase):
14*da0073e9SAndroid Build Coastguard Worker    @staticmethod
15*da0073e9SAndroid Build Coastguard Worker    def validate_transformed_module(
16*da0073e9SAndroid Build Coastguard Worker            # To please flake
17*da0073e9SAndroid Build Coastguard Worker            self,
18*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
19*da0073e9SAndroid Build Coastguard Worker            data_shape,
20*da0073e9SAndroid Build Coastguard Worker            prepack_removal=False,
21*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=False):
22*da0073e9SAndroid Build Coastguard Worker        module_instance = self
23*da0073e9SAndroid Build Coastguard Worker        scripted_model = torch.jit.script(module_instance)
24*da0073e9SAndroid Build Coastguard Worker        scripted_model.eval()
25*da0073e9SAndroid Build Coastguard Worker        input_data = torch.normal(1, 20, size=data_shape)
26*da0073e9SAndroid Build Coastguard Worker        ref_result = scripted_model(input_data)
27*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c)
28*da0073e9SAndroid Build Coastguard Worker        if fuse_clamping_ops or prepack_removal:
29*da0073e9SAndroid Build Coastguard Worker            scripted_model._c = torch._C._freeze_module(scripted_model._c)
30*da0073e9SAndroid Build Coastguard Worker        if fuse_clamping_ops:
31*da0073e9SAndroid Build Coastguard Worker            torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c)
32*da0073e9SAndroid Build Coastguard Worker        if prepack_removal:
33*da0073e9SAndroid Build Coastguard Worker            torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c)
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
36*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(scripted_model, buffer)
37*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
38*da0073e9SAndroid Build Coastguard Worker        deserialized_scripted_model = torch.jit.load(buffer)
39*da0073e9SAndroid Build Coastguard Worker        for pattern, v in pattern_count_map.items():
40*da0073e9SAndroid Build Coastguard Worker            if (v == 0):
41*da0073e9SAndroid Build Coastguard Worker                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
42*da0073e9SAndroid Build Coastguard Worker            elif (v == -1):
43*da0073e9SAndroid Build Coastguard Worker                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
44*da0073e9SAndroid Build Coastguard Worker            else:
45*da0073e9SAndroid Build Coastguard Worker                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker    def test_conv(self):
48*da0073e9SAndroid Build Coastguard Worker        # Conv params
49*da0073e9SAndroid Build Coastguard Worker        batch_size = 2
50*da0073e9SAndroid Build Coastguard Worker        input_channels_per_group = 6
51*da0073e9SAndroid Build Coastguard Worker        height = 16
52*da0073e9SAndroid Build Coastguard Worker        width = 16
53*da0073e9SAndroid Build Coastguard Worker        output_channels_per_group = 6
54*da0073e9SAndroid Build Coastguard Worker        groups = 4
55*da0073e9SAndroid Build Coastguard Worker        kernel_h = kernel_w = 3
56*da0073e9SAndroid Build Coastguard Worker        stride_h = stride_w = 1
57*da0073e9SAndroid Build Coastguard Worker        pad_h = pad_w = 1
58*da0073e9SAndroid Build Coastguard Worker        dilation = 1
59*da0073e9SAndroid Build Coastguard Worker        input_channels = input_channels_per_group * groups
60*da0073e9SAndroid Build Coastguard Worker        output_channels = output_channels_per_group * groups
61*da0073e9SAndroid Build Coastguard Worker        kernels = (kernel_h, kernel_w)
62*da0073e9SAndroid Build Coastguard Worker        strides = (stride_h, stride_w)
63*da0073e9SAndroid Build Coastguard Worker        paddings = (pad_h, pad_w)
64*da0073e9SAndroid Build Coastguard Worker        dilations = (dilation, dilation)
65*da0073e9SAndroid Build Coastguard Worker        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
66*da0073e9SAndroid Build Coastguard Worker        conv_bias_shape = (output_channels)
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker        class Conv2D(torch.nn.Module):
69*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
70*da0073e9SAndroid Build Coastguard Worker                super().__init__()
71*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
72*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
73*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
74*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
75*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
76*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
79*da0073e9SAndroid Build Coastguard Worker                return F.conv2d(x, self.weight, self.bias,
80*da0073e9SAndroid Build Coastguard Worker                                self.strides, self.paddings, self.dilations, self.groups)
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker        data_shape = (batch_size, input_channels, height, width)
83*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {"Tensor = aten::conv2d": -1,
84*da0073e9SAndroid Build Coastguard Worker                             "vulkan_prepack::conv2d_clamp_prepack": 1,
85*da0073e9SAndroid Build Coastguard Worker                             "vulkan_prepack::conv2d_clamp_run": 1}
86*da0073e9SAndroid Build Coastguard Worker        TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker        class Conv2DRelu(torch.nn.Module):
89*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
90*da0073e9SAndroid Build Coastguard Worker                super().__init__()
91*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
92*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
93*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
94*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
95*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
96*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
99*da0073e9SAndroid Build Coastguard Worker                o = F.conv2d(x, self.weight, self.bias,
100*da0073e9SAndroid Build Coastguard Worker                             self.strides, self.paddings, self.dilations, self.groups)
101*da0073e9SAndroid Build Coastguard Worker                o = F.relu(o)
102*da0073e9SAndroid Build Coastguard Worker                return o
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker        data_shape = (batch_size, input_channels, height, width)
105*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {"Tensor = aten::conv2d": -1,
106*da0073e9SAndroid Build Coastguard Worker                             "vulkan_prepack::conv2d_clamp_prepack": 1,
107*da0073e9SAndroid Build Coastguard Worker                             "vulkan_prepack::conv2d_clamp_run": 1}
108*da0073e9SAndroid Build Coastguard Worker        TestVulkanRewritePass.validate_transformed_module(
109*da0073e9SAndroid Build Coastguard Worker            Conv2DRelu(), pattern_count_map, data_shape)
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::relu"] = 1
112*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
113*da0073e9SAndroid Build Coastguard Worker        TestVulkanRewritePass.validate_transformed_module(
114*da0073e9SAndroid Build Coastguard Worker            Conv2DRelu(),
115*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
116*da0073e9SAndroid Build Coastguard Worker            data_shape,
117*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True)
118*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::relu"] = -1
119*da0073e9SAndroid Build Coastguard Worker        TestVulkanRewritePass.validate_transformed_module(
120*da0073e9SAndroid Build Coastguard Worker            Conv2DRelu(),
121*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
122*da0073e9SAndroid Build Coastguard Worker            data_shape,
123*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True,
124*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=True)
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker        class Conv2DHardtanh(torch.nn.Module):
128*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
129*da0073e9SAndroid Build Coastguard Worker                super().__init__()
130*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
131*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
132*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
133*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
134*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
135*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
138*da0073e9SAndroid Build Coastguard Worker                o = F.conv2d(x, self.weight, self.bias,
139*da0073e9SAndroid Build Coastguard Worker                             self.strides, self.paddings, self.dilations, self.groups)
140*da0073e9SAndroid Build Coastguard Worker                o = F.hardtanh(o)
141*da0073e9SAndroid Build Coastguard Worker                return o
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker        data_shape = (batch_size, input_channels, height, width)
144*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {"Tensor = aten::conv2d": -1,
145*da0073e9SAndroid Build Coastguard Worker                             "vulkan_prepack::conv2d_clamp_prepack": 1,
146*da0073e9SAndroid Build Coastguard Worker                             "vulkan_prepack::conv2d_clamp_run": 1}
147*da0073e9SAndroid Build Coastguard Worker        TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
148*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::hardtanh"] = 1
149*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
150*da0073e9SAndroid Build Coastguard Worker        TestVulkanRewritePass.validate_transformed_module(
151*da0073e9SAndroid Build Coastguard Worker            Conv2DHardtanh(),
152*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
153*da0073e9SAndroid Build Coastguard Worker            data_shape,
154*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True)
155*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::hardtanh"] = -1
156*da0073e9SAndroid Build Coastguard Worker        TestVulkanRewritePass.validate_transformed_module(
157*da0073e9SAndroid Build Coastguard Worker            Conv2DRelu(),
158*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
159*da0073e9SAndroid Build Coastguard Worker            data_shape,
160*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True,
161*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=True)
162*da0073e9SAndroid Build Coastguard Worker
def _main():
    """Entry point when this file is executed directly: run the test suite
    through PyTorch's shared ``run_tests`` helper."""
    run_tests()


if __name__ == "__main__":
    _main()
165