# xref: /aosp_15_r20/external/pytorch/test/ao/sparsity/test_sparsity_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: unknown"]
2
3
4import logging
5
6import torch
7from torch.ao.pruning.sparsifier.utils import (
8    fqn_to_module,
9    get_arg_info_from_tensor_fqn,
10    module_to_fqn,
11)
12from torch.testing._internal.common_quantization import (
13    ConvBnReLUModel,
14    ConvModel,
15    FunctionalLinear,
16    LinearAddModel,
17    ManualEmbeddingBagLinear,
18    SingleLayerLinearModel,
19    TwoLayerLinearModel,
20)
21from torch.testing._internal.common_utils import TestCase
22
23
24logging.basicConfig(
25    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
26)
27
28model_list = [
29    ConvModel,
30    SingleLayerLinearModel,
31    TwoLayerLinearModel,
32    LinearAddModel,
33    ConvBnReLUModel,
34    ManualEmbeddingBagLinear,
35    FunctionalLinear,
36]
37
38
class TestSparsityUtilFunctions(TestCase):
    """
    Tests for the fqn helpers in ``torch.ao.pruning.sparsifier.utils``:
    ``module_to_fqn``, ``fqn_to_module`` and ``get_arg_info_from_tensor_fqn``.

    Every test runs against each model class in ``model_list`` inside a
    ``subTest`` so that a failure reports which model triggered it.
    """

    @staticmethod
    def _tensor_fqn(module_fqn, tensor_name):
        """
        Build the fqn of a tensor attribute from the fqn of its owning module.

        Tensors held directly by the root module have ``module_fqn == ""`` and
        must not get a leading ``"."`` (e.g. ``"weight"``, not ``".weight"``).
        """
        if module_fqn == "":
            return tensor_name
        return f"{module_fqn}.{tensor_name}"

    def test_module_to_fqn(self):
        """
        Tests that module_to_fqn works as expected when compared to known good
        module.get_submodule(fqn) function
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                # named_modules() already yields the root model itself (under
                # the name ""), so the root case is covered without appending it.
                for _, module in model.named_modules():
                    fqn = module_to_fqn(model, module)
                    self.assertIs(model.get_submodule(fqn), module)

    def test_module_to_fqn_fail(self):
        """
        Tests that module_to_fqn returns None when the target module is not
        part of the model
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                # A freshly created module can never be a submodule of `model`.
                fqn = module_to_fqn(model, torch.nn.Linear(3, 3))
                self.assertIsNone(fqn)

    def test_module_to_fqn_root(self):
        """
        Tests that module_to_fqn returns '' when model and target module are the same
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                self.assertEqual(module_to_fqn(model, model), "")

    def test_fqn_to_module(self):
        """
        Tests that fqn_to_module operates as inverse
        of module_to_fqn
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for _, module in model.named_modules():
                    fqn = module_to_fqn(model, module)
                    self.assertIs(fqn_to_module(model, fqn), module)

    def test_fqn_to_module_fail(self):
        """
        Tests that fqn_to_module returns None for an fqn that does not
        resolve to anything inside the model
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                self.assertIsNone(fqn_to_module(model, "foo.bar.baz"))

    def test_fqn_to_module_for_tensors(self):
        """
        Tests that fqn_to_module works for tensors, actually all parameters
        of the model. This is tested by identifying a module with a tensor,
        and generating the tensor_fqn using module_to_fqn on the module +
        the name of the tensor.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for _, module in model.named_modules():
                    module_fqn = module_to_fqn(model, module)
                    # recurse=False: only tensors owned directly by `module`.
                    for tensor_name, tensor in module.named_parameters(recurse=False):
                        tensor_fqn = self._tensor_fqn(module_fqn, tensor_name)
                        # The lookup must return the very same Parameter object.
                        self.assertIs(fqn_to_module(model, tensor_fqn), tensor)

    def test_get_arg_info_from_tensor_fqn(self):
        """
        Tests that get_arg_info_from_tensor_fqn works for all parameters of the model.
        Generates a tensor_fqn in the same way as test_fqn_to_module_for_tensors and
        then compares with known (parent) module and tensor_name as well as module_fqn
        from module_to_fqn.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                for _, module in model.named_modules():
                    module_fqn = module_to_fqn(model, module)
                    for tensor_name, _ in module.named_parameters(recurse=False):
                        tensor_fqn = self._tensor_fqn(module_fqn, tensor_name)
                        arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                        self.assertIs(arg_info["module"], module)
                        self.assertEqual(arg_info["module_fqn"], module_fqn)
                        self.assertEqual(arg_info["tensor_name"], tensor_name)
                        self.assertEqual(arg_info["tensor_fqn"], tensor_fqn)

    def test_get_arg_info_from_tensor_fqn_fail(self):
        """
        Tests that get_arg_info_from_tensor_fqn works as expected for invalid tensor_fqn
        inputs. The string outputs still work but the output module is expected to be None.
        """
        for model_class in model_list:
            with self.subTest(model=model_class.__name__):
                model = model_class()
                tensor_fqn = "foo.bar.baz"
                arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                self.assertIsNone(arg_info["module"])
                self.assertEqual(arg_info["module_fqn"], "foo.bar")
                self.assertEqual(arg_info["tensor_name"], "baz")
                self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")
149            self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")
150