# Owner(s): ["oncall: mobile"]

import torch
import torch._C
import torch.nn.functional as F
from torch.testing._internal.common_utils import skipIfNoXNNPACK
from torch.testing._internal.jit_utils import JitTestCase


class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
    def check_replacement(
        self,
        model,
        replacements,
        jit_pass,
    ):
        """
        model: Model on which the optimization is performed
        replacements: Dict mapping each node kind in the optimized model
            to the kind of node it replaced in the original model
        jit_pass: Function that performs the optimization
        """

        original_kinds = set(replacements.values())
        original_source_ranges = {
            node.kind(): node.sourceRange()
            for node in model.graph.nodes()
            if node.kind() in original_kinds
        }

        jit_pass(model._c)

        for node in model.graph.nodes():
            if node.kind() in replacements:
                self.assertEqual(
                    node.sourceRange(),
                    original_source_ranges[replacements[node.kind()]],
                )

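    # Usage sketch (names drawn from the tests below): for a pass that rewrites
    # aten::linear into XNNPACK prepacked ops, the replacements dict maps each
    # new node kind back to the original kind whose source range it must carry,
    # e.g. {"prepacked::linear_clamp_run": "aten::linear"}.
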
    @skipIfNoXNNPACK
    def test_replace_conv1d_with_conv2d(self):
        class TestConv1d(torch.nn.Module):
            def __init__(self, weight, bias):
                super().__init__()
                self.weight = weight
                self.bias = bias

            def forward(self, x):
                return F.conv1d(x, self.weight, self.bias)

        self.check_replacement(
            model=torch.jit.script(
                TestConv1d(
                    weight=torch.rand(3, 3, 3),
                    bias=torch.rand(3),
                ),
            ),
            replacements={
                "prim::ListUnpack": "aten::conv1d",
                "prim::ListConstruct": "aten::conv1d",
                "aten::unsqueeze": "aten::conv1d",
                "aten::conv2d": "aten::conv1d",
                "aten::squeeze": "aten::conv1d",
            },
            jit_pass=torch._C._jit_pass_transform_conv1d_to_conv2d,
        )

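    # The pass exercised above rewrites aten::conv1d into an equivalent
    # aten::conv2d sandwiched between unsqueeze/squeeze nodes (along with the
    # prim::ListConstruct/prim::ListUnpack nodes it introduces), which is why
    # every new kind in the replacements dict maps back to aten::conv1d.
    # A minimal eager-mode sketch of that equivalence, using illustrative
    # shapes that are not part of this test, would be:
    #
    #     x = torch.rand(1, 3, 8)  # (N, C_in, L)
    #     w = torch.rand(3, 3, 3)  # (C_out, C_in, kernel)
    #     b = torch.rand(3)
    #     torch.testing.assert_close(
    #         F.conv1d(x, w, b),
    #         F.conv2d(x.unsqueeze(2), w.unsqueeze(2), b).squeeze(2),
    #     )
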
    @skipIfNoXNNPACK
    def test_insert_pre_packed_linear_before_inline_and_conv_2d_op(self):
        class TestPrepackedLinearBeforeInlineAndConv2dOp(torch.nn.Module):
            def __init__(
                self,
                linear_weight,
                linear_bias,
                conv2d_weight,
                conv2d_bias,
                conv_transpose2d_weight,
                conv_transpose2d_bias,
            ):
                super().__init__()
                self.linear_weight = linear_weight.float()
                self.linear_bias = linear_bias.float()
                self.conv2d_weight = conv2d_weight.float()
                self.conv2d_bias = conv2d_bias.float()
                self.conv_transpose2d_weight = conv_transpose2d_weight.float()
                self.conv_transpose2d_bias = conv_transpose2d_bias.float()

            def forward(self, x):
                linear_res = F.linear(
                    x.float(),
                    self.linear_weight,
                    self.linear_bias,
                )
                conv2d_res = F.conv2d(
                    input=linear_res.unsqueeze(dim=0).float(),
                    weight=self.conv2d_weight,
                    bias=self.conv2d_bias,
                )
                return F.conv_transpose2d(
                    input=conv2d_res,
                    weight=self.conv_transpose2d_weight,
                    bias=self.conv_transpose2d_bias,
                )

        minibatch = 1
        in_channels = 6
        iH = 4
        iW = 5
        out_channels = 6
        kH = 2
        kW = 3

        self.check_replacement(
            model=torch.jit.script(
                TestPrepackedLinearBeforeInlineAndConv2dOp(
                    linear_weight=torch.rand(iW, 3),
                    linear_bias=torch.rand(iW),
                    conv2d_weight=torch.rand(out_channels, in_channels, kH, kW),
                    conv2d_bias=torch.rand(out_channels),
                    conv_transpose2d_weight=torch.rand(
                        out_channels,
                        in_channels,
                        kH,
                        kW,
                    ),
                    conv_transpose2d_bias=torch.rand(out_channels),
                ),
            ),
            replacements={
                "prepacked::linear_clamp_prepack": "aten::linear",
                "prepacked::linear_clamp_run": "aten::linear",
                "prepacked::conv2d_clamp_prepack": "aten::conv2d",
                "prepacked::conv2d_clamp_run": "aten::conv2d",
                "prepacked::conv2d_transpose_clamp_prepack": "aten::conv_transpose2d",
                "prepacked::conv2d_transpose_clamp_run": "aten::conv_transpose2d",
            },
            jit_pass=torch._C._jit_pass_insert_prepacked_ops,
        )

    @skipIfNoXNNPACK
    def test_insert_pre_packed_linear_op(self):
        self.check_replacement(
            model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)),
            replacements={
                "prepacked::linear_clamp_prepack": "aten::linear",
                "prepacked::linear_clamp_run": "aten::linear",
            },
            jit_pass=torch._C._jit_pass_insert_prepacked_ops,
        )

    def run_test_fuse_activation_with_pack_ops_linear_conv2d(
        self,
        linear_activation,
        linear_activation_kind,
        conv2d_activation,
        conv2d_activation_kind,
    ):
        class TestFuseActivationLinearConv2d(torch.nn.Module):
            def __init__(
                self,
                linear_weight,
                linear_bias,
                conv2d_weight,
                conv2d_bias,
            ):
                super().__init__()
                self.linear_weight = linear_weight
                self.linear_bias = linear_bias
                self.conv2d_weight = conv2d_weight
                self.conv2d_bias = conv2d_bias

            def forward(self, x):
                x = F.linear(
                    input=x,
                    weight=self.linear_weight,
                    bias=self.linear_bias,
                )
                x = linear_activation(x)
                x = F.conv2d(
                    input=x.unsqueeze(dim=0),
                    weight=self.conv2d_weight,
                    bias=self.conv2d_bias,
                )
                return conv2d_activation(x)

        linear_in_features = 5
        linear_out_features = 4
        conv2d_in_channels = 3
        conv2d_out_channels = 4
        conv2d_kernel = 2
        x_shape = (3, 2, 5)

        model = torch.jit.trace(
            TestFuseActivationLinearConv2d(
                linear_weight=torch.nn.Parameter(
                    data=torch.rand(
                        linear_out_features,
                        linear_in_features,
                    ),
                    requires_grad=False,
                ),
                linear_bias=torch.nn.Parameter(
                    data=torch.rand(linear_out_features),
                    requires_grad=False,
                ),
                conv2d_weight=torch.rand(
                    conv2d_out_channels,
                    conv2d_in_channels,
                    conv2d_kernel,
                    conv2d_kernel,
                ),
                conv2d_bias=torch.rand(conv2d_out_channels),
            ),
            torch.rand(x_shape),
        )

        torch._C._jit_pass_insert_prepacked_ops(model._c)

        self.check_replacement(
            model=model,
            replacements={
                "prepacked::linear_clamp_prepack": "prepacked::linear_clamp_prepack",
                "prepacked::linear_clamp_run": linear_activation_kind,
                "prepacked::conv2d_clamp_prepack": "prepacked::conv2d_clamp_prepack",
                "prepacked::conv2d_clamp_run": conv2d_activation_kind,
            },
            jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv,
        )

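    # In the helper above, the fusion pass is expected to fold the clamping
    # activation (relu/hardtanh and their in-place variants) into the matching
    # prepacked::*_clamp_run op, so after the pass each run node should carry
    # the source range of the activation it absorbed, while the prepack nodes
    # keep their own ranges (hence the identity entries in the replacements
    # dict).
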
    @skipIfNoXNNPACK
    def test_fuse_activation_with_pack_ops_linear_conv2d_1(self):
        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
            linear_activation=F.hardtanh,
            linear_activation_kind="aten::hardtanh",
            conv2d_activation=F.hardtanh_,
            conv2d_activation_kind="aten::hardtanh_",
        )

    @skipIfNoXNNPACK
    def test_fuse_activation_with_pack_ops_linear_conv2d_2(self):
        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
            linear_activation=F.hardtanh_,
            linear_activation_kind="aten::hardtanh_",
            conv2d_activation=F.hardtanh,
            conv2d_activation_kind="aten::hardtanh",
        )

    @skipIfNoXNNPACK
    def test_fuse_activation_with_pack_ops_linear_conv2d_3(self):
        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
            linear_activation=F.relu,
            linear_activation_kind="aten::relu",
            conv2d_activation=F.relu_,
            conv2d_activation_kind="aten::relu_",
        )

    @skipIfNoXNNPACK
    def test_fuse_activation_with_pack_ops_linear_conv2d_4(self):
        self.run_test_fuse_activation_with_pack_ops_linear_conv2d(
            linear_activation=F.relu_,
            linear_activation_kind="aten::relu_",
            conv2d_activation=F.relu,
            conv2d_activation_kind="aten::relu",
        )