1# Owner(s): ["oncall: mobile"] 2 3import unittest 4import torch 5import torch.nn as nn 6import torch.utils.bundled_inputs 7from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK 8from torch.testing._internal.jit_utils import get_forward, get_forward_graph 9from torch.utils.mobile_optimizer import (LintCode, 10 generate_mobile_module_lints, 11 optimize_for_mobile, 12 MobileOptimizerType) 13from torch.nn import functional as F 14from torch.testing._internal.common_quantized import override_quantized_engine 15 16try: 17 import torchvision 18 HAS_TORCHVISION = True 19except ImportError: 20 HAS_TORCHVISION = False 21 22FileCheck = torch._C.FileCheck 23 24class TestOptimizer(TestCase): 25 26 @skipIfNoXNNPACK 27 def test_optimize_for_mobile(self): 28 batch_size = 2 29 input_channels_per_group = 6 30 height = 16 31 width = 16 32 output_channels_per_group = 6 33 groups = 4 34 kernel_h = kernel_w = 3 35 stride_h = stride_w = 1 36 pad_h = pad_w = 1 37 dilation = 1 38 input_channels = input_channels_per_group * groups 39 output_channels = output_channels_per_group * groups 40 kernels = (kernel_h, kernel_w) 41 strides = (stride_h, stride_w) 42 paddings = (pad_h, pad_w) 43 dilations = (dilation, dilation) 44 conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) 45 conv_bias_shape = (output_channels) 46 47 input_data = torch.rand((batch_size, input_channels, height, width)) 48 conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w)) 49 conv_bias = torch.rand(output_channels) 50 result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups) 51 weight_output_dim = 24 52 linear_input_shape = result.shape[1] 53 linear_weight_shape = (weight_output_dim, linear_input_shape) 54 55 class MyTestModule(torch.nn.Module): 56 def __init__(self) -> None: 57 super().__init__() 58 self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape)) 59 self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape)) 60 self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape)) 61 self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim)) 62 self.strides = strides 63 self.paddings = paddings 64 self.dilations = dilations 65 self.groups = groups 66 67 def forward(self, x): 68 o = F.conv2d(x, self.conv_weight, self.conv_bias, 69 self.strides, self.paddings, self.dilations, self.groups) 70 o = F.relu(o) 71 x = o.permute([0, 2, 3, 1]) 72 o = F.linear(x, self.linear_weight, self.linear_bias) 73 o = o + x 74 return F.relu(o) 75 76 @torch.jit.export 77 def foo(self, x): 78 o = F.conv2d(x, self.conv_weight, self.conv_bias, 79 self.strides, self.paddings, self.dilations, self.groups) 80 o = F.relu(o) 81 x = o.permute([0, 2, 3, 1]) 82 o = F.linear(x, self.linear_weight, self.linear_bias) 83 o = o + x 84 return F.relu(o) 85 86 87 class BNTestModule(torch.nn.Module): 88 def __init__(self) -> None: 89 super().__init__() 90 self.conv = torch.nn.Conv2d(1, 20, 5, 1) 91 self.bn = torch.nn.BatchNorm2d(num_features=20) 92 self.bn.eps = 0.0023 93 94 def forward(self, x): 95 x = self.conv(x) 96 x = self.bn(x) 97 return x 98 99 data_shape = (batch_size, input_channels, height, width) 100 input_data = torch.normal(1, 20, size=data_shape) 101 102 scripted_model = torch.jit.script(MyTestModule()) 103 scripted_model.eval() 104 initial_result = scripted_model(input_data) 105 initial_foo_result = scripted_model.foo(input_data) 106 107 optimized_scripted_model = optimize_for_mobile(scripted_model, preserved_methods=['foo']) 108 optimized_result = optimized_scripted_model(input_data) 109 optimized_foo_result = optimized_scripted_model.foo(input_data) 110 111 FileCheck().check_not("Tensor = aten::conv2d") \ 112 .check_not("Tensor = prim::CallFunction") \ 113 .check_not("prepacked::conv2d_clamp_prepack") \ 114 .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \ 115 .check_not("prepacked::linear_clamp_prepack") \ 116 .check_count("prepacked::linear_clamp_run", 1, exactly=True) \ 117 .check_not("aten::add(") \ 118 .check_not("aten::relu(") \ 119 .check_count("aten::_add_relu(", 1, exactly=True) \ 120 .run(optimized_scripted_model.graph) 121 torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3) 122 123 FileCheck().check_not("Tensor = aten::conv2d") \ 124 .check_not("Tensor = prim::CallFunction") \ 125 .check_not("prepacked::conv2d_clamp_prepack") \ 126 .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \ 127 .check_not("prepacked::linear_clamp_prepack") \ 128 .check_count("prepacked::linear_clamp_run", 1, exactly=True) \ 129 .check_not("aten::add(") \ 130 .check_not("aten::relu(") \ 131 .check_count("aten::_add_relu(", 1, exactly=True) \ 132 .run(optimized_scripted_model.foo.graph) 133 torch.testing.assert_close(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3) 134 135 136 optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS} 137 optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack) 138 optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data) 139 140 FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \ 141 .check_not("prepacked::linear_clamp_run") \ 142 .check_not("prepacked::conv2d_clamp_run") \ 143 .run(optimized_scripted_model_no_prepack.graph) 144 torch.testing.assert_close(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3) 145 146 147 bn_test_module = BNTestModule() 148 bn_scripted_module = torch.jit.script(bn_test_module) 149 bn_scripted_module.eval() 150 151 self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11) 152 FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \ 153 .run(str(get_forward(bn_scripted_module._c).graph)) 154 155 optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS} 156 bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack) 157 self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1) 158 bn_input = torch.rand(1, 1, 6, 6) 159 torch.testing.assert_close(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3) 160 161 optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION} 162 no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn) 163 FileCheck().check_count("aten::batch_norm", 1, exactly=True) \ 164 .run(str(get_forward_graph(no_bn_fold_scripted_module._c))) 165 bn_input = torch.rand(1, 1, 6, 6) 166 torch.testing.assert_close(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3) 167 168 class MyMobileOptimizedTagTest(torch.nn.Module): 169 def __init__(self) -> None: 170 super().__init__() 171 self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape)) 172 self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim)) 173 174 def forward(self, x): 175 o = F.linear(x, self.linear_weight, self.linear_bias) 176 return F.relu(o) 177 178 mobile_optimized_tag_module = MyMobileOptimizedTagTest() 179 m = torch.jit.script(mobile_optimized_tag_module) 180 m.eval() 181 opt_m = optimize_for_mobile(m) 182 tag = getattr(opt_m, "mobile_optimized", None) 183 self.assertTrue(tag) 184 185 class MyPreserveMethodsTest(torch.nn.Module): 186 def __init__(self) -> None: 187 super().__init__() 188 self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape)) 189 self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim)) 190 191 def forward(self, x): 192 o = F.linear(x, self.linear_weight, self.linear_bias) 193 return F.relu(o) 194 195 @torch.jit.export 196 def preserveThis(self): 197 pass 198 199 preserve_method_module = MyPreserveMethodsTest() 200 m = torch.jit.script(preserve_method_module) 201 m.eval() 202 opt_m = optimize_for_mobile(m) 203 no_preserveThis = getattr(opt_m, "preserveThis", None) 204 self.assertEqual(no_preserveThis, None) 205 opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"]) 206 preserveThis = getattr(opt_m, "preserveThis", None) 207 self.assertNotEqual(preserveThis, None) 208 209 class OptimizeNoForwardTest(torch.nn.Module): 210 def __init__(self) -> None: 211 super().__init__() 212 self.l = nn.Linear(10, 100) 213 self.l2 = nn.Linear(100, 1) 214 self.d = nn.Dropout(p=0.2) 215 216 @torch.jit.export 217 def foo(self, x): 218 x = self.d(F.relu(self.l(x))) 219 x = self.l2(x) 220 x = x + torch.ones(1, 100) 221 return F.relu(x) 222 input_data = torch.ones(1, 10) 223 m = torch.jit.script(OptimizeNoForwardTest()) 224 m.eval() 225 initial_result = m.foo(input_data) 226 227 optimized_scripted_model = optimize_for_mobile(m, preserved_methods=['foo']) 228 optimized_result = optimized_scripted_model.foo(input_data) 229 230 FileCheck().check_not("dropout.__") \ 231 .check_count("aten::_add_relu(", 1, exactly=True) \ 232 .run(optimized_scripted_model.foo.graph) 233 torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3) 234 235 class BNTestNoForwardModule(torch.nn.Module): 236 def __init__(self) -> None: 237 super().__init__() 238 self.conv = torch.nn.Conv2d(1, 20, 5, 1) 239 self.bn = torch.nn.BatchNorm2d(num_features=20) 240 self.bn.eps = 0.0023 241 242 @torch.jit.export 243 def foo(self, x): 244 x = self.conv(x) 245 x = self.bn(x) 246 return x 247 248 bn_test_no_forward_module = BNTestNoForwardModule() 249 bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module) 250 bn_no_forward_scripted_module.eval() 251 252 self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11) 253 FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \ 254 .run(bn_no_forward_scripted_module.foo.graph) 255 256 bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo']) 257 self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1) 258 bn_input = torch.rand(1, 1, 6, 6) 259 torch.testing.assert_close( 260 bn_no_forward_scripted_module.foo(bn_input), 261 bn_fold_no_forward_scripted_module.foo(bn_input), 262 rtol=1e-2, 263 atol=1e-3) 264 265 @skipIfNoXNNPACK 266 def test_quantized_conv_no_asan_failures(self): 267 # There were ASAN failures when fold_conv_bn was run on 268 # already quantized conv modules. Verifying that this does 269 # not happen again. 270 271 if 'qnnpack' not in torch.backends.quantized.supported_engines: 272 return 273 274 class Child(nn.Module): 275 def __init__(self) -> None: 276 super().__init__() 277 self.conv2 = nn.Conv2d(1, 1, 1) 278 279 def forward(self, x): 280 x = self.conv2(x) 281 return x 282 283 class Parent(nn.Module): 284 def __init__(self) -> None: 285 super().__init__() 286 self.quant = torch.ao.quantization.QuantStub() 287 self.conv1 = nn.Conv2d(1, 1, 1) 288 self.child = Child() 289 self.dequant = torch.ao.quantization.DeQuantStub() 290 291 def forward(self, x): 292 x = self.quant(x) 293 x = self.conv1(x) 294 x = self.child(x) 295 x = self.dequant(x) 296 return x 297 298 with override_quantized_engine('qnnpack'): 299 model = Parent() 300 model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack') 301 torch.ao.quantization.prepare(model, inplace=True) 302 model(torch.randn(4, 1, 4, 4)) 303 torch.ao.quantization.convert(model, inplace=True) 304 model = torch.jit.script(model) 305 # this line should not have ASAN failures 306 model_optim = optimize_for_mobile(model) 307 308 def test_generate_mobile_module_lints(self): 309 class MyTestModule(torch.nn.Module): 310 def __init__(self) -> None: 311 super().__init__() 312 self.fc = torch.nn.Linear(4, 4) 313 self.dropout = torch.nn.Dropout(p=0.5) 314 315 def forward(self, inputs): 316 out = self.fc(inputs) 317 out = self.dropout(out) 318 return out 319 320 class MyBNModule(torch.nn.Module): 321 def __init__(self) -> None: 322 super().__init__() 323 self.bn = torch.nn.BatchNorm2d(4, affine=True) 324 325 def forward(self, inputs): 326 bn = self.bn(inputs) 327 return bn 328 329 class MyBundledInputModule(torch.nn.Module): 330 def forward(self, inputs): 331 return inputs 332 333 def get_lint_count_by_type(lint_type, module_lint_List): 334 return len([lint_dict for lint_dict in module_lint_List if lint_dict['name'] == lint_type.name]) 335 336 test_module = torch.jit.script(MyTestModule()) 337 test_module_lint_list = generate_mobile_module_lints(test_module) 338 self.assertEqual(len(test_module_lint_list), 4) 339 self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1) 340 self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1) 341 self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2) 342 343 bn_module = torch.jit.script(MyBNModule()) 344 bn_module_lint_list = generate_mobile_module_lints(bn_module) 345 self.assertEqual(len(bn_module_lint_list), 4) 346 self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1) 347 self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1) 348 self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2) 349 350 bi_module = torch.jit.script(MyBundledInputModule()) 351 torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 352 bi_module, [(torch.tensor([1]),)], []) 353 bi_module_lint_list = generate_mobile_module_lints(bi_module) 354 self.assertEqual(len(bi_module_lint_list), 0) 355 356 @skipIfNoXNNPACK 357 def test_preserve_bundled_inputs_methods(self): 358 class MyBundledInputModule(torch.nn.Module): 359 def forward(self, inputs): 360 return inputs 361 362 class MyIncompleteBundledInputModule(torch.nn.Module): 363 def forward(self, inputs): 364 return inputs 365 366 @torch.jit.export 367 def get_all_bundled_inputs(self): 368 pass 369 370 bi_module = torch.jit.script(MyBundledInputModule()) 371 module_optim_bi_not_preserved = optimize_for_mobile(bi_module) 372 373 # Expected to be False since no bundled inputs methods were added 374 self.assertFalse( 375 hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs') or 376 hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs') 377 ) 378 379 # Add bundled inputs methods to the module 380 torch.utils.bundled_inputs.augment_model_with_bundled_inputs( 381 bi_module, [(torch.tensor([1]),)], []) 382 # Now they should be preserved 383 module_optim_bi_preserved = optimize_for_mobile(bi_module) 384 385 # All of the bundled inputs methods were preserved 386 self.assertTrue( 387 hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs') and 388 hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs') 389 ) 390 391 bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0] 392 module_optim_bi_preserved(*bundled_input) 393 394 # If not all 3 bundled inputs methods are present in the module, 395 # we will not try to preserve them unless specified by the user. 396 incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule()) 397 incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module) 398 self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs')) 399 400 # Specifically preserve get_all_bundled_inputs even if it's the only one 401 # bundled inputs method available. 402 incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs']) 403 self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs')) 404 405 @skipIfNoXNNPACK 406 def test_hoist_conv_packed_params(self): 407 408 if 'qnnpack' not in torch.backends.quantized.supported_engines: 409 return 410 411 class Standalone(nn.Module): 412 def __init__(self) -> None: 413 super().__init__() 414 self.quant = torch.ao.quantization.QuantStub() 415 self.conv1 = nn.Conv2d(1, 1, 1) 416 self.conv2 = nn.Conv2d(1, 1, 1) 417 self.relu = nn.ReLU() 418 self.dequant = torch.ao.quantization.DeQuantStub() 419 420 def forward(self, x): 421 x = self.quant(x) 422 x = self.conv1(x) 423 x = self.conv2(x) 424 x = self.relu(x) 425 x = self.dequant(x) 426 return x 427 428 def fuse_model(self): 429 torch.ao.quantization.fuse_modules(self, [['conv2', 'relu']], inplace=True) 430 431 class Child(nn.Module): 432 def __init__(self) -> None: 433 super().__init__() 434 self.conv1 = nn.Conv2d(1, 1, 1) 435 436 def forward(self, x): 437 x = self.conv1(x) 438 return x 439 440 class Parent(nn.Module): 441 def __init__(self) -> None: 442 super().__init__() 443 self.quant = torch.ao.quantization.QuantStub() 444 self.conv1 = nn.Conv2d(1, 1, 1) 445 self.child = Child() 446 # TODO: test nn.Sequential after #42039 is fixed 447 self.dequant = torch.ao.quantization.DeQuantStub() 448 449 def forward(self, x): 450 x = self.quant(x) 451 x = self.conv1(x) 452 x = self.child(x) 453 x = self.dequant(x) 454 return x 455 456 def fuse_model(self): 457 pass 458 459 with override_quantized_engine('qnnpack'): 460 def _quant_script_and_optimize(model): 461 model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack') 462 model.fuse_model() 463 torch.ao.quantization.prepare(model, inplace=True) 464 model(torch.randn(4, 1, 4, 4)) 465 torch.ao.quantization.convert(model, inplace=True) 466 model = torch.jit.script(model) 467 model_optim = optimize_for_mobile(model) 468 return model, model_optim 469 470 # basic case 471 472 m, m_optim = _quant_script_and_optimize(Standalone()) 473 FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \ 474 .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \ 475 .run(m_optim.graph) 476 self.assertFalse(hasattr(m_optim, "conv1")) 477 self.assertFalse(hasattr(m_optim, "conv2")) 478 479 data = torch.randn(4, 1, 4, 4) 480 m_res = m(data) 481 m_optim_res = m_optim(data) 482 torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3) 483 484 # generic case 485 486 m, m_optim = _quant_script_and_optimize(Parent()) 487 FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \ 488 .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \ 489 .run(m_optim.graph) 490 self.assertFalse(hasattr(m_optim, "conv1")) 491 self.assertFalse(hasattr(m_optim, "child")) 492 493 data = torch.randn(4, 1, 4, 4) 494 m_res = m(data) 495 m_optim_res = m_optim(data) 496 torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3) 497 498 @skipIfNoXNNPACK 499 @unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision") 500 def test_mobilenet_optimize_for_mobile(self): 501 m = torchvision.models.mobilenet_v3_small() 502 m = torch.jit.script(m) 503 m = optimize_for_mobile(m) 504 505 # run forward 3 times until segfault, see https://github.com/pytorch/pytorch/issues/52463 506 x = torch.zeros(1, 3, 56, 56) 507 self.assertEqual(m(x).numel(), 1000) 508 self.assertEqual(m(x).numel(), 1000) 509 self.assertEqual(m(x).numel(), 1000) 510 511 def test_clone_module_with_class(self): 512 class MyInnerTestModule(torch.nn.Module): 513 def __init__(self) -> None: 514 super().__init__() 515 self.pqr = torch.Tensor([10., 20., 30.]) 516 517 def forward(self, inputs): 518 return inputs 519 520 @torch.jit.export 521 def dummy_method_not_cloned(self): 522 return 20 523 524 class MyTestModule(torch.nn.Module): 525 def __init__(self) -> None: 526 super().__init__() 527 self.abc = 23 528 self.pqr = torch.Tensor([1., 2., 3.]) 529 self.inner = MyInnerTestModule() 530 531 def forward(self, inputs): 532 x = self.dummy_method_cloned() 533 # The call to self.inner.dummy_method_not_cloned should not raise an error 534 y = self.inner.dummy_method_not_cloned() 535 # The call to self.inner.pqr should not raise an error 536 z = self.inner.pqr 537 return (inputs, x, y, z) 538 539 @torch.jit.export 540 def dummy_method_not_cloned2(self): 541 # The call to self.inner.dummy_method_not_cloned should not raise an error 542 y = self.inner.dummy_method_not_cloned() 543 # The call to self.inner.pqr should not raise an error 544 z = self.inner.pqr 545 return self.pqr, self.dummy_method_not_cloned(), y, z 546 547 @torch.jit.export 548 def dummy_method_not_cloned(self): 549 return None 550 551 @torch.jit.export 552 def dummy_method_cloned(self): 553 return None 554 555 @torch.jit.export 556 def dummy_method_ref_attr_pqr(self): 557 return self.pqr, self.inner.pqr 558 559 m = torch.jit.script(MyTestModule()) 560 561 # Check that the methods exist on the original model. 562 self.assertEqual(hasattr(m, "dummy_method_not_cloned"), True) 563 self.assertEqual(hasattr(m, "dummy_method_cloned"), True) 564 self.assertEqual(hasattr(m, "dummy_method_not_cloned2"), True) 565 self.assertEqual(hasattr(m, "pqr"), True) 566 567 # Case-1: Successfully clone, ignoring 2 methods, keeping all attributes. 568 cloned = torch._C._hack_do_not_use_clone_module_with_class( 569 m._c, 570 ["dummy_method_not_cloned", "dummy_method_not_cloned2"], # ignored_methods 571 [], # ignored_attributes 572 ) 573 574 # Check that the ignored methods don't exist on the cloned model. 575 self.assertEqual(hasattr(cloned, "dummy_method_not_cloned"), False) 576 self.assertEqual(hasattr(cloned, "dummy_method_cloned"), True) 577 self.assertEqual(hasattr(cloned, "dummy_method_not_cloned2"), False) 578 self.assertEqual(hasattr(cloned, "pqr"), True) 579 580 # Check that the cloned class has a classname that starts with __torch__. 581 self.assertTrue( 582 cloned.qualified_name.startswith('__torch__.'), 583 ("Expected the cloned module's name to start with the string " 584 f"'__torch__.', but got: {cloned.qualified_name}"), 585 ) 586 587 588 # Case-2: Successfully clone the module, ignoring the attribute pqr, and the method that references it. 589 cloned = torch._C._hack_do_not_use_clone_module_with_class( 590 m._c, 591 ["dummy_method_not_cloned", "dummy_method_not_cloned2", "dummy_method_ref_attr_pqr"], 592 ["pqr"], 593 ) 594 595 # Check that the ignored methods don't exist on the cloned model. 596 self.assertEqual(hasattr(cloned, "dummy_method_not_cloned"), False) 597 self.assertEqual(hasattr(cloned, "dummy_method_cloned"), True) 598 self.assertEqual(hasattr(cloned, "dummy_method_not_cloned2"), False) 599 self.assertEqual(hasattr(cloned, "dummy_method_ref_attr_pqr"), False) 600 self.assertEqual(hasattr(cloned, "pqr"), False) 601 602 603 # Case-3: The statement below will throw since dummy_method_cloned2 is preserved, 604 # and references dummy_method_not_cloned, which is not cloned. 605 with self.assertRaises(RuntimeError): 606 cloned = torch._C._hack_do_not_use_clone_module_with_class(m._c, ["dummy_method_not_cloned"], []) 607 608 # Case-4: The statement below will throw since dummy_method_ref_attr_pqr 609 # is preserved, and references "pqr", which is not cloned. 610 with self.assertRaises(RuntimeError): 611 cloned = torch._C._hack_do_not_use_clone_module_with_class( 612 m._c, 613 ["dummy_method_not_cloned", "dummy_method_not_cloned2"], 614 ["pqr"], 615 ) 616 617 618if __name__ == '__main__': 619 run_tests() 620