xref: /aosp_15_r20/external/pytorch/test/ao/sparsity/test_data_scheduler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: unknown"]
2
3import copy
4import logging
5import warnings
6from typing import Tuple
7
8import torch
9from torch import nn
10from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler
11from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
12from torch.testing._internal.common_utils import TestCase
13
14
# Configure root logging at import time: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
18
19
class ImplementedDataScheduler(BaseDataScheduler):
    """Minimal concrete scheduler used as a fixture by the tests in this file.

    While ``last_epoch`` is not positive it hands back ``self.base_param``
    unchanged; afterwards it proposes half of the ``sparsity_level`` currently
    stored in each of the sparsifier's data groups.
    """

    def __init__(self, sparsifier, sparsifier_hyperparam, last_epoch=-1, verbose=False):
        # Pure pass-through to the base class; nothing extra is stored here.
        super().__init__(sparsifier, sparsifier_hyperparam, last_epoch, verbose)

    def get_schedule_param(self):
        """Return a mapping from data-group name to the value to schedule."""
        if self.last_epoch <= 0:
            return self.base_param

        halved = {}
        for group_name, group_config in self.data_sparsifier.data_groups.items():
            halved[group_name] = group_config["sparsity_level"] * 0.5
        return halved
32
33
34class TestBaseDataScheduler(TestCase):
35    def _get_data(self):
36        tensor1, param1, emb1 = (
37            torch.randn(5, 5),
38            nn.Parameter(torch.randn(10, 10)),
39            nn.Embedding(50, 5),
40        )
41        data_list = [("tensor1", tensor1), ("param1", param1), ("emb1", emb1)]
42        defaults = {
43            "sparsity_level": 0.7,
44            "sparse_block_shape": (1, 4),
45            "zeros_per_block": 2,
46        }
47        data_with_config = [
48            {
49                "name": "tensor2",
50                "data": torch.randn(4, 4),
51                "config": {"sparsity_level": 0.3},
52            }
53        ]
54        return data_list, data_with_config, defaults
55
56    def _get_sparsifier(self, data_list, data_with_config, defaults):
57        sparsifier = DataNormSparsifier(data_list, **defaults)
58        for data_config_dict in data_with_config:
59            name, data, config = (
60                data_config_dict["name"],
61                data_config_dict["data"],
62                data_config_dict["config"],
63            )
64            sparsifier.add_data(name=name, data=data, **config)
65        return sparsifier
66
67    def _get_scheduler(self, sparsifier, schedule_param):
68        scheduler = ImplementedDataScheduler(sparsifier, schedule_param)
69        return scheduler
70
71    def _get_schedule_param(self):
72        return "sparsity_level"
73
74    def _get_name_data_config(self, some_data, defaults):
75        config = copy.deepcopy(defaults)
76        if isinstance(some_data, Tuple):
77            # dealing with data_list
78            name, data = some_data
79        else:
80            # dealing with data_with_config
81            name, data, new_config = (
82                some_data["name"],
83                some_data["data"],
84                some_data["config"],
85            )
86            config.update(new_config)
87        return name, data, config
88
89    def test_constructor(self):
90        """Checks if the warning is thrown if the scheduler step is called
91        before the sparsifier step"""
92        data_list, data_with_config, defaults = self._get_data()
93        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
94        schedule_param = self._get_schedule_param()
95        scheduler = self._get_scheduler(sparsifier, schedule_param)
96
97        assert scheduler.data_sparsifier == sparsifier
98        assert scheduler._step_count == 1
99
100        for name, config in sparsifier.data_groups.items():
101            assert scheduler.base_param[name] == config.get(schedule_param, None)
102
103    def test_order_of_steps(self):
104        data_list, data_with_config, defaults = self._get_data()
105        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
106        schedule_param = self._get_schedule_param()
107        scheduler = self._get_scheduler(sparsifier, schedule_param)
108
109        # Sparsifier step is not called
110        with self.assertWarns(UserWarning):
111            scheduler.step()
112
113        # Correct order has no warnings
114        # Note: This will trigger if other warnings are present.
115        with warnings.catch_warnings(record=True) as w:
116            sparsifier.step()
117            scheduler.step()
118            # Make sure there is no warning related to the base_data_scheduler
119            for warning in w:
120                fname = warning.filename
121                fname = "/".join(fname.split("/")[-5:])
122                assert (
123                    fname
124                    != "torch/ao/sparsity/experimental/scheduler/data_scheduler/base_data_scheduler.py"
125                )
126
127    def test_step(self):
128        data_list, data_with_config, defaults = self._get_data()
129        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
130        schedule_param = self._get_schedule_param()
131        scheduler = self._get_scheduler(sparsifier, schedule_param)
132
133        all_data = data_list + data_with_config
134
135        for some_data in all_data:
136            name, _, config = self._get_name_data_config(some_data, defaults)
137            assert (
138                sparsifier.data_groups[name][schedule_param] == config[schedule_param]
139            )
140
141        sparsifier.step()
142        scheduler.step()
143
144        for some_data in all_data:
145            name, _, config = self._get_name_data_config(some_data, defaults)
146            assert (
147                sparsifier.data_groups[name][schedule_param]
148                == config[schedule_param] * 0.5
149            )
150
151        # checking step count
152        step_cnt = 5
153        for _ in range(0, step_cnt):
154            sparsifier.step()
155            scheduler.step()
156
157        assert (
158            scheduler._step_count == step_cnt + 2
159        )  # step_cnt + step above + 1 step in constructor
160
161    def test_state_dict(self):
162        data_list, data_with_config, defaults = self._get_data()
163        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
164        schedule_param = self._get_schedule_param()
165        scheduler1 = self._get_scheduler(sparsifier, schedule_param)
166
167        sparsifier.step()
168        scheduler1.step()
169
170        scheduler2 = self._get_scheduler(sparsifier, schedule_param)
171        all_data = data_list + data_with_config
172        for some_data in all_data:
173            name, _, _ = self._get_name_data_config(some_data, defaults)
174            assert scheduler1.base_param[name] != scheduler2.base_param[name]
175            assert scheduler1._last_param[name] == scheduler2.base_param[name]
176
177        scheduler1_state = scheduler1.state_dict()
178        scheduler2.load_state_dict(scheduler1_state)
179
180        for some_data in all_data:
181            name, _, _ = self._get_name_data_config(some_data, defaults)
182            assert scheduler1.base_param[name] == scheduler2.base_param[name]
183            assert scheduler1._last_param[name] == scheduler2._last_param[name]
184