# xref: /aosp_15_r20/external/pytorch/test/test_mobile_optimizer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: mobile"]
2
3import unittest
4import torch
5import torch.nn as nn
6import torch.utils.bundled_inputs
7from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK
8from torch.testing._internal.jit_utils import get_forward, get_forward_graph
9from torch.utils.mobile_optimizer import (LintCode,
10                                          generate_mobile_module_lints,
11                                          optimize_for_mobile,
12                                          MobileOptimizerType)
13from torch.nn import functional as F
14from torch.testing._internal.common_quantized import override_quantized_engine
15
# torchvision is optional: the MobileNet test is skipped when it is missing.
try:
    import torchvision
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

# Short alias for TorchScript's graph-pattern checker used throughout the tests.
FileCheck = torch._C.FileCheck
23
class TestOptimizer(TestCase):
    """Tests for ``torch.utils.mobile_optimizer`` and the related module-cloning hook.

    Covers ``optimize_for_mobile`` (prepacked conv/linear rewriting, conv-bn
    folding, add+relu fusion, method preservation, hoisting of quantized conv
    packed params), ``generate_mobile_module_lints``, preservation of
    bundled-input methods, and
    ``torch._C._hack_do_not_use_clone_module_with_class``.
    """

    @skipIfNoXNNPACK
    def test_optimize_for_mobile(self):
        """optimize_for_mobile rewrites conv/linear/add/relu, folds BN and honours blocklists."""
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        # Trailing comma: a 1-tuple shape, not a parenthesised int.
        conv_bias_shape = (output_channels,)

        # forward()/foo() permute the conv output with [0, 2, 3, 1], moving the
        # channel dim last, so the linear layer's input width is the conv's
        # output channel count (dim 1 of an F.conv2d result). The residual
        # `o + x` then requires the linear output width to equal it as well.
        weight_output_dim = output_channels
        linear_input_shape = output_channels
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape))
                self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape))
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

            # Same computation as forward(); used to check that preserved
            # methods receive the same optimizations.
            @torch.jit.export
            def foo(self, x):
                o = F.conv2d(x, self.conv_weight, self.conv_bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                x = o.permute([0, 2, 3, 1])
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = o + x
                return F.relu(o)

        class BNTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        def check_fully_optimized(graph):
            # conv2d/linear must be replaced by prepacked::*_clamp_run ops with no
            # *_clamp_prepack ops left in the graph, and add+relu must be fused
            # into a single aten::_add_relu.
            (FileCheck().check_not("Tensor = aten::conv2d")
                        .check_not("Tensor = prim::CallFunction")
                        .check_not("prepacked::conv2d_clamp_prepack")
                        .check_count("prepacked::conv2d_clamp_run", 1, exactly=True)
                        .check_not("prepacked::linear_clamp_prepack")
                        .check_count("prepacked::linear_clamp_run", 1, exactly=True)
                        .check_not("aten::add(")
                        .check_not("aten::relu(")
                        .check_count("aten::_add_relu(", 1, exactly=True)
                        .run(graph))

        data_shape = (batch_size, input_channels, height, width)
        input_data = torch.normal(1, 20, size=data_shape)

        scripted_model = torch.jit.script(MyTestModule())
        scripted_model.eval()
        initial_result = scripted_model(input_data)
        initial_foo_result = scripted_model.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(scripted_model, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model(input_data)
        optimized_foo_result = optimized_scripted_model.foo(input_data)

        check_fully_optimized(optimized_scripted_model.graph)
        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        check_fully_optimized(optimized_scripted_model.foo.graph)
        torch.testing.assert_close(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3)

        # Blocklisting INSERT_FOLD_PREPACK_OPS must leave aten::conv2d in place.
        optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
        optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
        optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data)

        (FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True)
                    .check_not("prepacked::linear_clamp_run")
                    .check_not("prepacked::conv2d_clamp_run")
                    .run(optimized_scripted_model_no_prepack.graph))
        torch.testing.assert_close(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3)

        bn_scripted_module = torch.jit.script(BNTestModule())
        bn_scripted_module.eval()

        # Before optimization forward() dispatches to conv and bn submodules.
        self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11)
        (FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True)
                    .run(str(get_forward(bn_scripted_module._c).graph)))

        bn_input = torch.rand(1, 1, 6, 6)

        # With prepacking blocklisted, conv-bn folding leaves a single operator.
        bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack)
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
        torch.testing.assert_close(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        # Blocklisting CONV_BN_FUSION must keep the batch_norm op.
        optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION}
        no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn)
        (FileCheck().check_count("aten::batch_norm", 1, exactly=True)
                    .run(str(get_forward_graph(no_bn_fold_scripted_module._c))))
        torch.testing.assert_close(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

        class MyMobileOptimizedTagTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

        # The optimized module must carry a truthy "mobile_optimized" attribute.
        m = torch.jit.script(MyMobileOptimizedTagTest())
        m.eval()
        opt_m = optimize_for_mobile(m)
        tag = getattr(opt_m, "mobile_optimized", None)
        self.assertTrue(tag)

        class MyPreserveMethodsTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

            @torch.jit.export
            def preserveThis(self):
                pass

        # Exported methods are dropped unless listed in preserved_methods.
        m = torch.jit.script(MyPreserveMethodsTest())
        m.eval()
        opt_m = optimize_for_mobile(m)
        self.assertIsNone(getattr(opt_m, "preserveThis", None))
        opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
        self.assertIsNotNone(getattr(opt_m, "preserveThis", None))

        class OptimizeNoForwardTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = nn.Linear(10, 100)
                self.l2 = nn.Linear(100, 1)
                self.d = nn.Dropout(p=0.2)

            @torch.jit.export
            def foo(self, x):
                x = self.d(F.relu(self.l(x)))
                x = self.l2(x)
                x = x + torch.ones(1, 100)
                return F.relu(x)

        # A module without forward() still gets dropout removal and add+relu fusion
        # on its preserved methods.
        input_data = torch.ones(1, 10)
        m = torch.jit.script(OptimizeNoForwardTest())
        m.eval()
        initial_result = m.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(m, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model.foo(input_data)

        (FileCheck().check_not("dropout.__")
                    .check_count("aten::_add_relu(", 1, exactly=True)
                    .run(optimized_scripted_model.foo.graph))
        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        class BNTestNoForwardModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            @torch.jit.export
            def foo(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        # Conv-bn folding must also apply to preserved non-forward methods.
        bn_no_forward_scripted_module = torch.jit.script(BNTestNoForwardModule())
        bn_no_forward_scripted_module.eval()

        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11)
        (FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True)
                    .run(bn_no_forward_scripted_module.foo.graph))

        bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo'])
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(
            bn_no_forward_scripted_module.foo(bn_input),
            bn_fold_no_forward_scripted_module.foo(bn_input),
            rtol=1e-2,
            atol=1e-3)

    @skipIfNoXNNPACK
    def test_quantized_conv_no_asan_failures(self):
        """Regression test: fold_conv_bn on already-quantized convs used to trip ASAN."""
        # There were ASAN failures when fold_conv_bn was run on
        # already quantized conv modules. Verifying that this does
        # not happen again.

        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            # Report a skip rather than a silent pass so missing coverage is visible.
            self.skipTest("qnnpack quantized engine is not supported in this build")

        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv2(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

        with override_quantized_engine('qnnpack'):
            model = Parent()
            model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
            torch.ao.quantization.prepare(model, inplace=True)
            model(torch.randn(4, 1, 4, 4))
            torch.ao.quantization.convert(model, inplace=True)
            model = torch.jit.script(model)
            # this line should not have ASAN failures
            optimize_for_mobile(model)

    def test_generate_mobile_module_lints(self):
        """generate_mobile_module_lints reports dropout, batchnorm, grad and bundled-input lints."""
        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(4, 4)
                self.dropout = torch.nn.Dropout(p=0.5)

            def forward(self, inputs):
                out = self.fc(inputs)
                out = self.dropout(out)
                return out

        class MyBNModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(4, affine=True)

            def forward(self, inputs):
                bn = self.bn(inputs)
                return bn

        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        def get_lint_count_by_type(lint_type, module_lint_list):
            # Each lint is a dict whose 'name' matches a LintCode member name.
            return sum(1 for lint_dict in module_lint_list if lint_dict['name'] == lint_type.name)

        test_module = torch.jit.script(MyTestModule())
        test_module_lint_list = generate_mobile_module_lints(test_module)
        self.assertEqual(len(test_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2)

        bn_module = torch.jit.script(MyBNModule())
        bn_module_lint_list = generate_mobile_module_lints(bn_module)
        self.assertEqual(len(bn_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2)

        # A parameter-free module with bundled inputs produces no lints at all.
        bi_module = torch.jit.script(MyBundledInputModule())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        bi_module_lint_list = generate_mobile_module_lints(bi_module)
        self.assertEqual(len(bi_module_lint_list), 0)

    @skipIfNoXNNPACK
    def test_preserve_bundled_inputs_methods(self):
        """Bundled-input methods survive optimize_for_mobile only when the full set is present."""
        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        class MyIncompleteBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def get_all_bundled_inputs(self):
                pass

        bi_module = torch.jit.script(MyBundledInputModule())
        module_optim_bi_not_preserved = optimize_for_mobile(bi_module)

        # Expected to be False since no bundled inputs methods were added
        self.assertFalse(
            hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs') or
            hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs')
        )

        # Add bundled inputs methods to the module
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        # Now they should be preserved
        module_optim_bi_preserved = optimize_for_mobile(bi_module)

        # All of the bundled inputs methods were preserved
        self.assertTrue(
            hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs') and
            hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs')
        )

        # The preserved methods must still be usable on the optimized module.
        bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0]
        module_optim_bi_preserved(*bundled_input)

        # If not all 3 bundled inputs methods are present in the module,
        # we will not try to preserve them unless specified by the user.
        incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule())
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module)
        self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

        # Specifically preserve get_all_bundled_inputs even if it's the only one
        # bundled inputs method available.
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
        self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

    @skipIfNoXNNPACK
    def test_hoist_conv_packed_params(self):
        """Quantized conv packed params are hoisted to graph constants and submodules removed."""
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            # Report a skip rather than a silent pass so missing coverage is visible.
            self.skipTest("qnnpack quantized engine is not supported in this build")

        class Standalone(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)
                self.relu = nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                torch.ao.quantization.fuse_modules(self, [['conv2', 'relu']], inplace=True)

        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                # TODO: test nn.Sequential after #42039 is fixed
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                pass

        with override_quantized_engine('qnnpack'):
            def _quant_script_and_optimize(model):
                # Eager-mode post-training quantization, then script and optimize.
                # Returns (scripted quantized model, optimized model).
                model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                model.fuse_model()
                torch.ao.quantization.prepare(model, inplace=True)
                model(torch.randn(4, 1, 4, 4))
                torch.ao.quantization.convert(model, inplace=True)
                model = torch.jit.script(model)
                model_optim = optimize_for_mobile(model)
                return model, model_optim

            # basic case

            m, m_optim = _quant_script_and_optimize(Standalone())
            (FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]')
                        .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True)
                        .run(m_optim.graph))
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "conv2"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

            # generic case

            m, m_optim = _quant_script_and_optimize(Parent())
            (FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]')
                        .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True)
                        .run(m_optim.graph))
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "child"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

    @skipIfNoXNNPACK
    @unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision")
    def test_mobilenet_optimize_for_mobile(self):
        """Optimized mobilenet_v3_small must survive repeated forward calls."""
        m = torchvision.models.mobilenet_v3_small()
        m = torch.jit.script(m)
        m = optimize_for_mobile(m)

        # Run forward three times: repeated calls used to segfault,
        # see https://github.com/pytorch/pytorch/issues/52463
        x = torch.zeros(1, 3, 56, 56)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)

    def test_clone_module_with_class(self):
        """_hack_do_not_use_clone_module_with_class honours ignored methods/attributes."""
        class MyInnerTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.pqr = torch.tensor([10., 20., 30.])

            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def dummy_method_not_cloned(self):
                return 20

        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.abc = 23
                self.pqr = torch.tensor([1., 2., 3.])
                self.inner = MyInnerTestModule()

            def forward(self, inputs):
                x = self.dummy_method_cloned()
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # Accessing self.inner.pqr should not raise an error
                z = self.inner.pqr
                return (inputs, x, y, z)

            @torch.jit.export
            def dummy_method_not_cloned2(self):
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # Accessing self.inner.pqr should not raise an error
                z = self.inner.pqr
                return self.pqr, self.dummy_method_not_cloned(), y, z

            @torch.jit.export
            def dummy_method_not_cloned(self):
                return None

            @torch.jit.export
            def dummy_method_cloned(self):
                return None

            @torch.jit.export
            def dummy_method_ref_attr_pqr(self):
                return self.pqr, self.inner.pqr

        m = torch.jit.script(MyTestModule())

        # Check that the methods exist on the original model.
        self.assertTrue(hasattr(m, "dummy_method_not_cloned"))
        self.assertTrue(hasattr(m, "dummy_method_cloned"))
        self.assertTrue(hasattr(m, "dummy_method_not_cloned2"))
        self.assertTrue(hasattr(m, "pqr"))

        # Case-1: Successfully clone, ignoring 2 methods, keeping all attributes.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2"],  # ignored_methods
            [],  # ignored_attributes
        )

        # Check that the ignored methods don't exist on the cloned model.
        self.assertFalse(hasattr(cloned, "dummy_method_not_cloned"))
        self.assertTrue(hasattr(cloned, "dummy_method_cloned"))
        self.assertFalse(hasattr(cloned, "dummy_method_not_cloned2"))
        self.assertTrue(hasattr(cloned, "pqr"))

        # Check that the cloned class has a classname that starts with __torch__.
        self.assertTrue(
            cloned.qualified_name.startswith('__torch__.'),
            ("Expected the cloned module's name to start with the string "
             f"'__torch__.', but got: {cloned.qualified_name}"),
        )

        # Case-2: Successfully clone the module, ignoring the attribute pqr, and the method that references it.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2", "dummy_method_ref_attr_pqr"],
            ["pqr"],
        )

        # Check that the ignored methods don't exist on the cloned model.
        self.assertFalse(hasattr(cloned, "dummy_method_not_cloned"))
        self.assertTrue(hasattr(cloned, "dummy_method_cloned"))
        self.assertFalse(hasattr(cloned, "dummy_method_not_cloned2"))
        self.assertFalse(hasattr(cloned, "dummy_method_ref_attr_pqr"))
        self.assertFalse(hasattr(cloned, "pqr"))

        # Case-3: The statement below will throw since dummy_method_not_cloned2 is
        # preserved, and references dummy_method_not_cloned, which is not cloned.
        with self.assertRaises(RuntimeError):
            torch._C._hack_do_not_use_clone_module_with_class(m._c, ["dummy_method_not_cloned"], [])

        # Case-4: The statement below will throw since dummy_method_ref_attr_pqr
        # is preserved, and references "pqr", which is not cloned.
        with self.assertRaises(RuntimeError):
            torch._C._hack_do_not_use_clone_module_with_class(
                m._c,
                ["dummy_method_not_cloned", "dummy_method_not_cloned2"],
                ["pqr"],
            )
616
617
if __name__ == "__main__":
    # Executed directly: hand off to PyTorch's shared test entry point.
    run_tests()
620