xref: /aosp_15_r20/external/pytorch/test/test_metal.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import functional as F
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
8*da0073e9SAndroid Build Coastguard Workerimport io
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerclass TestMetalRewritePass(TestCase):
11*da0073e9SAndroid Build Coastguard Worker    @staticmethod
12*da0073e9SAndroid Build Coastguard Worker    def validate_transformed_module(
13*da0073e9SAndroid Build Coastguard Worker            # To please flake
14*da0073e9SAndroid Build Coastguard Worker            self,
15*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
16*da0073e9SAndroid Build Coastguard Worker            data_shape,
17*da0073e9SAndroid Build Coastguard Worker            prepack_removal=False,
18*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=False):
19*da0073e9SAndroid Build Coastguard Worker        module_instance = self
20*da0073e9SAndroid Build Coastguard Worker        scripted_model = torch.jit.script(module_instance)
21*da0073e9SAndroid Build Coastguard Worker        scripted_model.eval()
22*da0073e9SAndroid Build Coastguard Worker        input_data = torch.normal(1, 20, size=data_shape)
23*da0073e9SAndroid Build Coastguard Worker        ref_result = scripted_model(input_data)
24*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c)
25*da0073e9SAndroid Build Coastguard Worker        if fuse_clamping_ops or prepack_removal:
26*da0073e9SAndroid Build Coastguard Worker            scripted_model._c = torch._C._freeze_module(scripted_model._c)
27*da0073e9SAndroid Build Coastguard Worker        if fuse_clamping_ops:
28*da0073e9SAndroid Build Coastguard Worker            torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c)
29*da0073e9SAndroid Build Coastguard Worker        if prepack_removal:
30*da0073e9SAndroid Build Coastguard Worker            torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c)
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
33*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(scripted_model, buffer)
34*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
35*da0073e9SAndroid Build Coastguard Worker        deserialized_scripted_model = torch.jit.load(buffer)
36*da0073e9SAndroid Build Coastguard Worker        for pattern, v in pattern_count_map.items():
37*da0073e9SAndroid Build Coastguard Worker            if (v == 0):
38*da0073e9SAndroid Build Coastguard Worker                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
39*da0073e9SAndroid Build Coastguard Worker            elif (v == -1):
40*da0073e9SAndroid Build Coastguard Worker                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
41*da0073e9SAndroid Build Coastguard Worker            else:
42*da0073e9SAndroid Build Coastguard Worker                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker    def test_conv(self):
45*da0073e9SAndroid Build Coastguard Worker        # Conv params
46*da0073e9SAndroid Build Coastguard Worker        batch_size = 2
47*da0073e9SAndroid Build Coastguard Worker        input_channels_per_group = 6
48*da0073e9SAndroid Build Coastguard Worker        height = 16
49*da0073e9SAndroid Build Coastguard Worker        width = 16
50*da0073e9SAndroid Build Coastguard Worker        output_channels_per_group = 6
51*da0073e9SAndroid Build Coastguard Worker        groups = 4
52*da0073e9SAndroid Build Coastguard Worker        kernel_h = kernel_w = 3
53*da0073e9SAndroid Build Coastguard Worker        stride_h = stride_w = 1
54*da0073e9SAndroid Build Coastguard Worker        pad_h = pad_w = 1
55*da0073e9SAndroid Build Coastguard Worker        dilation = 1
56*da0073e9SAndroid Build Coastguard Worker        input_channels = input_channels_per_group * groups
57*da0073e9SAndroid Build Coastguard Worker        output_channels = output_channels_per_group * groups
58*da0073e9SAndroid Build Coastguard Worker        kernels = (kernel_h, kernel_w)
59*da0073e9SAndroid Build Coastguard Worker        strides = (stride_h, stride_w)
60*da0073e9SAndroid Build Coastguard Worker        paddings = (pad_h, pad_w)
61*da0073e9SAndroid Build Coastguard Worker        dilations = (dilation, dilation)
62*da0073e9SAndroid Build Coastguard Worker        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
63*da0073e9SAndroid Build Coastguard Worker        conv_bias_shape = (output_channels)
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker        class Conv2D(torch.nn.Module):
66*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
67*da0073e9SAndroid Build Coastguard Worker                super().__init__()
68*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
69*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
70*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
71*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
72*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
73*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
76*da0073e9SAndroid Build Coastguard Worker                return F.conv2d(x, self.weight, self.bias,
77*da0073e9SAndroid Build Coastguard Worker                                self.strides, self.paddings, self.dilations, self.groups)
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker        data_shape = (batch_size, input_channels, height, width)
80*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {"Tensor = aten::conv2d": -1,
81*da0073e9SAndroid Build Coastguard Worker                             "metal_prepack::conv2d_prepack": 1,
82*da0073e9SAndroid Build Coastguard Worker                             "metal_prepack::conv2d_run": 1}
83*da0073e9SAndroid Build Coastguard Worker        TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker        class Conv2DRelu(torch.nn.Module):
86*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
87*da0073e9SAndroid Build Coastguard Worker                super().__init__()
88*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
89*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
90*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
91*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
92*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
93*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
96*da0073e9SAndroid Build Coastguard Worker                o = F.conv2d(x, self.weight, self.bias,
97*da0073e9SAndroid Build Coastguard Worker                             self.strides, self.paddings, self.dilations, self.groups)
98*da0073e9SAndroid Build Coastguard Worker                o = F.relu(o)
99*da0073e9SAndroid Build Coastguard Worker                return o
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker        data_shape = (batch_size, input_channels, height, width)
102*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {"Tensor = aten::conv2d": -1,
103*da0073e9SAndroid Build Coastguard Worker                             "metal_prepack::conv2d_prepack": 1,
104*da0073e9SAndroid Build Coastguard Worker                             "metal_prepack::conv2d_run": 1}
105*da0073e9SAndroid Build Coastguard Worker        TestMetalRewritePass.validate_transformed_module(
106*da0073e9SAndroid Build Coastguard Worker            Conv2DRelu(), pattern_count_map, data_shape)
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::relu"] = 1
109*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["metal_prepack::conv2d_prepack"] = -1
110*da0073e9SAndroid Build Coastguard Worker        TestMetalRewritePass.validate_transformed_module(
111*da0073e9SAndroid Build Coastguard Worker            Conv2DRelu(),
112*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
113*da0073e9SAndroid Build Coastguard Worker            data_shape,
114*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True)
115*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::relu"] = -1
116*da0073e9SAndroid Build Coastguard Worker        TestMetalRewritePass.validate_transformed_module(
117*da0073e9SAndroid Build Coastguard Worker            Conv2DRelu(),
118*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
119*da0073e9SAndroid Build Coastguard Worker            data_shape,
120*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True,
121*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=True)
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker
124*da0073e9SAndroid Build Coastguard Worker        class Conv2DHardtanh(torch.nn.Module):
125*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
126*da0073e9SAndroid Build Coastguard Worker                super().__init__()
127*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
128*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
129*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
130*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
131*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
132*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
135*da0073e9SAndroid Build Coastguard Worker                o = F.conv2d(x, self.weight, self.bias,
136*da0073e9SAndroid Build Coastguard Worker                             self.strides, self.paddings, self.dilations, self.groups)
137*da0073e9SAndroid Build Coastguard Worker                o = F.hardtanh(o)
138*da0073e9SAndroid Build Coastguard Worker                return o
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker        data_shape = (batch_size, input_channels, height, width)
141*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {"Tensor = aten::conv2d": -1,
142*da0073e9SAndroid Build Coastguard Worker                             "metal_prepack::conv2d_prepack": 1,
143*da0073e9SAndroid Build Coastguard Worker                             "metal_prepack::conv2d_run": 1}
144*da0073e9SAndroid Build Coastguard Worker        TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
145*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::hardtanh"] = 1
146*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["metal_prepack::conv2d_prepack"] = -1
147*da0073e9SAndroid Build Coastguard Worker        TestMetalRewritePass.validate_transformed_module(
148*da0073e9SAndroid Build Coastguard Worker            Conv2DHardtanh(),
149*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
150*da0073e9SAndroid Build Coastguard Worker            data_shape,
151*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True)
152*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::hardtanh"] = -1
153*da0073e9SAndroid Build Coastguard Worker        TestMetalRewritePass.validate_transformed_module(
154*da0073e9SAndroid Build Coastguard Worker            Conv2DRelu(),
155*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
156*da0073e9SAndroid Build Coastguard Worker            data_shape,
157*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True,
158*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=True)
159*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run this file's tests through the runner from torch.testing._internal.common_utils.
    run_tests()
162