# Owner(s): ["module: unknown"]

import copy
import logging
import warnings

import torch
from torch import nn
from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
from torch.testing._internal.common_utils import TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class ImplementedDataScheduler(BaseDataScheduler):
    def __init__(self, sparsifier, sparsifier_hyperparam, last_epoch=-1, verbose=False):
        super().__init__(sparsifier, sparsifier_hyperparam, last_epoch, verbose)

    def get_schedule_param(self):
        # Halve the sparsity level of every data group after the first epoch
        if self.last_epoch > 0:
            return {
                name: config["sparsity_level"] * 0.5
                for name, config in self.data_sparsifier.data_groups.items()
            }
        else:
            return self.base_param


class TestBaseDataScheduler(TestCase):
    def _get_data(self):
        tensor1, param1, emb1 = (
            torch.randn(5, 5),
            nn.Parameter(torch.randn(10, 10)),
            nn.Embedding(50, 5),
        )
        data_list = [("tensor1", tensor1), ("param1", param1), ("emb1", emb1)]
        defaults = {
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 2,
        }
        data_with_config = [
            {
                "name": "tensor2",
                "data": torch.randn(4, 4),
                "config": {"sparsity_level": 0.3},
            }
        ]
        return data_list, data_with_config, defaults

    def _get_sparsifier(self, data_list, data_with_config, defaults):
        sparsifier = DataNormSparsifier(data_list, **defaults)
        for data_config_dict in data_with_config:
            name, data, config = (
                data_config_dict["name"],
                data_config_dict["data"],
                data_config_dict["config"],
            )
            sparsifier.add_data(name=name, data=data, **config)
        return sparsifier

    def _get_scheduler(self, sparsifier, schedule_param):
        scheduler = ImplementedDataScheduler(sparsifier, schedule_param)
        return scheduler

    def _get_schedule_param(self):
        return "sparsity_level"

    def _get_name_data_config(self, some_data, defaults):
        config = copy.deepcopy(defaults)
        if isinstance(some_data, tuple):
            # dealing with data_list
            name, data = some_data
        else:
            # dealing with data_with_config
            name, data, new_config = (
                some_data["name"],
                some_data["data"],
                some_data["config"],
            )
            config.update(new_config)
        return name, data, config

    def test_constructor(self):
        """Checks that the scheduler attaches to the sparsifier and that the
        base schedule parameters are recorded on construction"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        assert scheduler.data_sparsifier == sparsifier
        assert scheduler._step_count == 1

        for name, config in sparsifier.data_groups.items():
            assert scheduler.base_param[name] == config.get(schedule_param, None)

    def test_order_of_steps(self):
        """Checks that a warning is thrown if the scheduler step is called
        before the sparsifier step"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        # Sparsifier step is not called
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Correct order has no warnings
        # Note: This will trigger if other warnings are present.
        with warnings.catch_warnings(record=True) as w:
            sparsifier.step()
            scheduler.step()
        # Make sure there is no warning related to the base_data_scheduler
        # (the expected suffix mirrors the module imported at the top of this file)
        for warning in w:
            fname = warning.filename
            fname = "/".join(fname.split("/")[-5:])
            assert (
                fname
                != "ao/pruning/_experimental/data_scheduler/base_data_scheduler.py"
            )

    def test_step(self):
        """Checks that stepping the scheduler rescales the schedule parameter
        in the sparsifier and that the step count is tracked correctly"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        all_data = data_list + data_with_config

        # Before stepping, each data group holds its configured sparsity level
        for some_data in all_data:
            name, _, config = self._get_name_data_config(some_data, defaults)
            assert (
                sparsifier.data_groups[name][schedule_param] == config[schedule_param]
            )

        sparsifier.step()
        scheduler.step()

        # After one step, the scheduler halves each sparsity level
        for some_data in all_data:
            name, _, config = self._get_name_data_config(some_data, defaults)
            assert (
                sparsifier.data_groups[name][schedule_param]
                == config[schedule_param] * 0.5
            )

        # checking step count
        step_cnt = 5
        for _ in range(step_cnt):
            sparsifier.step()
            scheduler.step()

        assert (
            scheduler._step_count == step_cnt + 2
        )  # step_cnt + step above + 1 step in constructor

    def test_state_dict(self):
        """Checks that loading one scheduler's state dict into another
        restores the base and last schedule parameters"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler1 = self._get_scheduler(sparsifier, schedule_param)

        sparsifier.step()
        scheduler1.step()

        scheduler2 = self._get_scheduler(sparsifier, schedule_param)
        all_data = data_list + data_with_config
        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data, defaults)
            assert scheduler1.base_param[name] != scheduler2.base_param[name]
            assert scheduler1._last_param[name] == scheduler2.base_param[name]

        scheduler1_state = scheduler1.state_dict()
        scheduler2.load_state_dict(scheduler1_state)

        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data, defaults)
            assert scheduler1.base_param[name] == scheduler2.base_param[name]
            assert scheduler1._last_param[name] == scheduler2._last_param[name]
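

# A minimal sketch of the conventional entry point for PyTorch test files,
# assuming the standard run_tests helper from
# torch.testing._internal.common_utils (the same module TestCase is imported
# from above); it lets this file be executed directly rather than only
# through a test runner.
if __name__ == "__main__":
    from torch.testing._internal.common_utils import run_tests

    run_tests()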