1# Owner(s): ["module: unknown"] 2 3 4import logging 5 6import torch 7from torch.ao.pruning.sparsifier.utils import ( 8 fqn_to_module, 9 get_arg_info_from_tensor_fqn, 10 module_to_fqn, 11) 12from torch.testing._internal.common_quantization import ( 13 ConvBnReLUModel, 14 ConvModel, 15 FunctionalLinear, 16 LinearAddModel, 17 ManualEmbeddingBagLinear, 18 SingleLayerLinearModel, 19 TwoLayerLinearModel, 20) 21from torch.testing._internal.common_utils import TestCase 22 23 24logging.basicConfig( 25 format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO 26) 27 28model_list = [ 29 ConvModel, 30 SingleLayerLinearModel, 31 TwoLayerLinearModel, 32 LinearAddModel, 33 ConvBnReLUModel, 34 ManualEmbeddingBagLinear, 35 FunctionalLinear, 36] 37 38 39class TestSparsityUtilFunctions(TestCase): 40 def test_module_to_fqn(self): 41 """ 42 Tests that module_to_fqn works as expected when compared to known good 43 module.get_submodule(fqn) function 44 """ 45 for model_class in model_list: 46 model = model_class() 47 list_of_modules = [m for _, m in model.named_modules()] + [model] 48 for module in list_of_modules: 49 fqn = module_to_fqn(model, module) 50 check_module = model.get_submodule(fqn) 51 self.assertEqual(module, check_module) 52 53 def test_module_to_fqn_fail(self): 54 """ 55 Tests that module_to_fqn returns None when an fqn that doesn't 56 correspond to a path to a node/tensor is given 57 """ 58 for model_class in model_list: 59 model = model_class() 60 fqn = module_to_fqn(model, torch.nn.Linear(3, 3)) 61 self.assertEqual(fqn, None) 62 63 def test_module_to_fqn_root(self): 64 """ 65 Tests that module_to_fqn returns '' when model and target module are the same 66 """ 67 for model_class in model_list: 68 model = model_class() 69 fqn = module_to_fqn(model, model) 70 self.assertEqual(fqn, "") 71 72 def test_fqn_to_module(self): 73 """ 74 Tests that fqn_to_module operates as inverse 75 of module_to_fqn 76 """ 77 for model_class in model_list: 78 model = model_class() 79 list_of_modules = [m for _, m in model.named_modules()] + [model] 80 for module in list_of_modules: 81 fqn = module_to_fqn(model, module) 82 check_module = fqn_to_module(model, fqn) 83 self.assertEqual(module, check_module) 84 85 def test_fqn_to_module_fail(self): 86 """ 87 Tests that fqn_to_module returns None when it tries to 88 find an fqn of a module outside the model 89 """ 90 for model_class in model_list: 91 model = model_class() 92 fqn = "foo.bar.baz" 93 check_module = fqn_to_module(model, fqn) 94 self.assertEqual(check_module, None) 95 96 def test_fqn_to_module_for_tensors(self): 97 """ 98 Tests that fqn_to_module works for tensors, actually all parameters 99 of the model. This is tested by identifying a module with a tensor, 100 and generating the tensor_fqn using module_to_fqn on the module + 101 the name of the tensor. 102 """ 103 for model_class in model_list: 104 model = model_class() 105 list_of_modules = [m for _, m in model.named_modules()] + [model] 106 for module in list_of_modules: 107 module_fqn = module_to_fqn(model, module) 108 for tensor_name, tensor in module.named_parameters(recurse=False): 109 tensor_fqn = ( # string manip to handle tensors on root 110 module_fqn + ("." if module_fqn != "" else "") + tensor_name 111 ) 112 check_tensor = fqn_to_module(model, tensor_fqn) 113 self.assertEqual(tensor, check_tensor) 114 115 def test_get_arg_info_from_tensor_fqn(self): 116 """ 117 Tests that get_arg_info_from_tensor_fqn works for all parameters of the model. 118 Generates a tensor_fqn in the same way as test_fqn_to_module_for_tensors and 119 then compares with known (parent) module and tensor_name as well as module_fqn 120 from module_to_fqn. 121 """ 122 for model_class in model_list: 123 model = model_class() 124 list_of_modules = [m for _, m in model.named_modules()] + [model] 125 for module in list_of_modules: 126 module_fqn = module_to_fqn(model, module) 127 for tensor_name, tensor in module.named_parameters(recurse=False): 128 tensor_fqn = ( 129 module_fqn + ("." if module_fqn != "" else "") + tensor_name 130 ) 131 arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn) 132 self.assertEqual(arg_info["module"], module) 133 self.assertEqual(arg_info["module_fqn"], module_fqn) 134 self.assertEqual(arg_info["tensor_name"], tensor_name) 135 self.assertEqual(arg_info["tensor_fqn"], tensor_fqn) 136 137 def test_get_arg_info_from_tensor_fqn_fail(self): 138 """ 139 Tests that get_arg_info_from_tensor_fqn works as expected for invalid tensor_fqn 140 inputs. The string outputs still work but the output module is expected to be None. 141 """ 142 for model_class in model_list: 143 model = model_class() 144 tensor_fqn = "foo.bar.baz" 145 arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn) 146 self.assertEqual(arg_info["module"], None) 147 self.assertEqual(arg_info["module_fqn"], "foo.bar") 148 self.assertEqual(arg_info["tensor_name"], "baz") 149 self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz") 150